update
This commit is contained in:
0
vllm_old/entrypoints/__init__.py
Normal file
0
vllm_old/entrypoints/__init__.py
Normal file
BIN
vllm_old/entrypoints/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
vllm_old/entrypoints/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm_old/entrypoints/__pycache__/api_server.cpython-312.pyc
Normal file
BIN
vllm_old/entrypoints/__pycache__/api_server.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm_old/entrypoints/__pycache__/chat_utils.cpython-312.pyc
Normal file
BIN
vllm_old/entrypoints/__pycache__/chat_utils.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm_old/entrypoints/__pycache__/constants.cpython-312.pyc
Normal file
BIN
vllm_old/entrypoints/__pycache__/constants.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm_old/entrypoints/__pycache__/context.cpython-312.pyc
Normal file
BIN
vllm_old/entrypoints/__pycache__/context.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm_old/entrypoints/__pycache__/dynamic_lora.cpython-312.pyc
Normal file
BIN
vllm_old/entrypoints/__pycache__/dynamic_lora.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm_old/entrypoints/__pycache__/harmony_utils.cpython-312.pyc
Normal file
BIN
vllm_old/entrypoints/__pycache__/harmony_utils.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm_old/entrypoints/__pycache__/launcher.cpython-312.pyc
Normal file
BIN
vllm_old/entrypoints/__pycache__/launcher.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm_old/entrypoints/__pycache__/llm.cpython-312.pyc
Normal file
BIN
vllm_old/entrypoints/__pycache__/llm.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm_old/entrypoints/__pycache__/logger.cpython-312.pyc
Normal file
BIN
vllm_old/entrypoints/__pycache__/logger.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm_old/entrypoints/__pycache__/renderer.cpython-312.pyc
Normal file
BIN
vllm_old/entrypoints/__pycache__/renderer.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm_old/entrypoints/__pycache__/responses_utils.cpython-312.pyc
Normal file
BIN
vllm_old/entrypoints/__pycache__/responses_utils.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm_old/entrypoints/__pycache__/score_utils.cpython-312.pyc
Normal file
BIN
vllm_old/entrypoints/__pycache__/score_utils.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm_old/entrypoints/__pycache__/ssl.cpython-312.pyc
Normal file
BIN
vllm_old/entrypoints/__pycache__/ssl.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm_old/entrypoints/__pycache__/tool.cpython-312.pyc
Normal file
BIN
vllm_old/entrypoints/__pycache__/tool.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm_old/entrypoints/__pycache__/tool_server.cpython-312.pyc
Normal file
BIN
vllm_old/entrypoints/__pycache__/tool_server.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm_old/entrypoints/__pycache__/utils.cpython-312.pyc
Normal file
BIN
vllm_old/entrypoints/__pycache__/utils.cpython-312.pyc
Normal file
Binary file not shown.
0
vllm_old/entrypoints/anthropic/__init__.py
Normal file
0
vllm_old/entrypoints/anthropic/__init__.py
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
162
vllm_old/entrypoints/anthropic/protocol.py
Normal file
162
vllm_old/entrypoints/anthropic/protocol.py
Normal file
@@ -0,0 +1,162 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Pydantic models for Anthropic API protocol"""
|
||||
|
||||
import time
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
|
||||
class AnthropicError(BaseModel):
|
||||
"""Error structure for Anthropic API"""
|
||||
|
||||
type: str
|
||||
message: str
|
||||
|
||||
|
||||
class AnthropicErrorResponse(BaseModel):
|
||||
"""Error response structure for Anthropic API"""
|
||||
|
||||
type: Literal["error"] = "error"
|
||||
error: AnthropicError
|
||||
|
||||
|
||||
class AnthropicUsage(BaseModel):
|
||||
"""Token usage information"""
|
||||
|
||||
input_tokens: int
|
||||
output_tokens: int
|
||||
cache_creation_input_tokens: int | None = None
|
||||
cache_read_input_tokens: int | None = None
|
||||
|
||||
|
||||
class AnthropicContentBlock(BaseModel):
|
||||
"""Content block in message"""
|
||||
|
||||
type: Literal["text", "image", "tool_use", "tool_result"]
|
||||
text: str | None = None
|
||||
# For image content
|
||||
source: dict[str, Any] | None = None
|
||||
# For tool use/result
|
||||
id: str | None = None
|
||||
name: str | None = None
|
||||
input: dict[str, Any] | None = None
|
||||
content: str | list[dict[str, Any]] | None = None
|
||||
is_error: bool | None = None
|
||||
|
||||
|
||||
class AnthropicMessage(BaseModel):
|
||||
"""Message structure"""
|
||||
|
||||
role: Literal["user", "assistant"]
|
||||
content: str | list[AnthropicContentBlock]
|
||||
|
||||
|
||||
class AnthropicTool(BaseModel):
|
||||
"""Tool definition"""
|
||||
|
||||
name: str
|
||||
description: str | None = None
|
||||
input_schema: dict[str, Any]
|
||||
|
||||
@field_validator("input_schema")
|
||||
@classmethod
|
||||
def validate_input_schema(cls, v):
|
||||
if not isinstance(v, dict):
|
||||
raise ValueError("input_schema must be a dictionary")
|
||||
if "type" not in v:
|
||||
v["type"] = "object" # Default to object type
|
||||
return v
|
||||
|
||||
|
||||
class AnthropicToolChoice(BaseModel):
|
||||
"""Tool Choice definition"""
|
||||
|
||||
type: Literal["auto", "any", "tool"]
|
||||
name: str | None = None
|
||||
|
||||
|
||||
class AnthropicMessagesRequest(BaseModel):
|
||||
"""Anthropic Messages API request"""
|
||||
|
||||
model: str
|
||||
messages: list[AnthropicMessage]
|
||||
max_tokens: int
|
||||
metadata: dict[str, Any] | None = None
|
||||
stop_sequences: list[str] | None = None
|
||||
stream: bool | None = False
|
||||
system: str | list[AnthropicContentBlock] | None = None
|
||||
temperature: float | None = None
|
||||
tool_choice: AnthropicToolChoice | None = None
|
||||
tools: list[AnthropicTool] | None = None
|
||||
top_k: int | None = None
|
||||
top_p: float | None = None
|
||||
|
||||
@field_validator("model")
|
||||
@classmethod
|
||||
def validate_model(cls, v):
|
||||
if not v:
|
||||
raise ValueError("Model is required")
|
||||
return v
|
||||
|
||||
@field_validator("max_tokens")
|
||||
@classmethod
|
||||
def validate_max_tokens(cls, v):
|
||||
if v <= 0:
|
||||
raise ValueError("max_tokens must be positive")
|
||||
return v
|
||||
|
||||
|
||||
class AnthropicDelta(BaseModel):
|
||||
"""Delta for streaming responses"""
|
||||
|
||||
type: Literal["text_delta", "input_json_delta"] | None = None
|
||||
text: str | None = None
|
||||
partial_json: str | None = None
|
||||
|
||||
# Message delta
|
||||
stop_reason: (
|
||||
Literal["end_turn", "max_tokens", "stop_sequence", "tool_use"] | None
|
||||
) = None
|
||||
stop_sequence: str | None = None
|
||||
|
||||
|
||||
class AnthropicStreamEvent(BaseModel):
|
||||
"""Streaming event"""
|
||||
|
||||
type: Literal[
|
||||
"message_start",
|
||||
"message_delta",
|
||||
"message_stop",
|
||||
"content_block_start",
|
||||
"content_block_delta",
|
||||
"content_block_stop",
|
||||
"ping",
|
||||
"error",
|
||||
]
|
||||
message: Optional["AnthropicMessagesResponse"] = None
|
||||
delta: AnthropicDelta | None = None
|
||||
content_block: AnthropicContentBlock | None = None
|
||||
index: int | None = None
|
||||
error: AnthropicError | None = None
|
||||
usage: AnthropicUsage | None = None
|
||||
|
||||
|
||||
class AnthropicMessagesResponse(BaseModel):
|
||||
"""Anthropic Messages API response"""
|
||||
|
||||
id: str
|
||||
type: Literal["message"] = "message"
|
||||
role: Literal["assistant"] = "assistant"
|
||||
content: list[AnthropicContentBlock]
|
||||
model: str
|
||||
stop_reason: (
|
||||
Literal["end_turn", "max_tokens", "stop_sequence", "tool_use"] | None
|
||||
) = None
|
||||
stop_sequence: str | None = None
|
||||
usage: AnthropicUsage | None = None
|
||||
|
||||
def model_post_init(self, __context):
|
||||
if not self.id:
|
||||
self.id = f"msg_{int(time.time() * 1000)}"
|
||||
460
vllm_old/entrypoints/anthropic/serving_messages.py
Normal file
460
vllm_old/entrypoints/anthropic/serving_messages.py
Normal file
@@ -0,0 +1,460 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# Adapted from
|
||||
# https://github.com/vllm/vllm/entrypoints/openai/serving_chat.py
|
||||
|
||||
"""Anthropic Messages API serving handler"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.anthropic.protocol import (
|
||||
AnthropicContentBlock,
|
||||
AnthropicDelta,
|
||||
AnthropicError,
|
||||
AnthropicMessagesRequest,
|
||||
AnthropicMessagesResponse,
|
||||
AnthropicStreamEvent,
|
||||
AnthropicUsage,
|
||||
)
|
||||
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionNamedToolChoiceParam,
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionStreamResponse,
|
||||
ChatCompletionToolsParam,
|
||||
ErrorResponse,
|
||||
StreamOptions,
|
||||
)
|
||||
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def wrap_data_with_event(data: str, event: str):
|
||||
return f"event: {event}\ndata: {data}\n\n"
|
||||
|
||||
|
||||
class AnthropicServingMessages(OpenAIServingChat):
|
||||
"""Handler for Anthropic Messages API requests"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
models: OpenAIServingModels,
|
||||
response_role: str,
|
||||
*,
|
||||
request_logger: RequestLogger | None,
|
||||
chat_template: str | None,
|
||||
chat_template_content_format: ChatTemplateContentFormatOption,
|
||||
return_tokens_as_token_ids: bool = False,
|
||||
reasoning_parser: str = "",
|
||||
enable_auto_tools: bool = False,
|
||||
tool_parser: str | None = None,
|
||||
enable_prompt_tokens_details: bool = False,
|
||||
enable_force_include_usage: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
engine_client=engine_client,
|
||||
models=models,
|
||||
response_role=response_role,
|
||||
request_logger=request_logger,
|
||||
chat_template=chat_template,
|
||||
chat_template_content_format=chat_template_content_format,
|
||||
return_tokens_as_token_ids=return_tokens_as_token_ids,
|
||||
reasoning_parser=reasoning_parser,
|
||||
enable_auto_tools=enable_auto_tools,
|
||||
tool_parser=tool_parser,
|
||||
enable_prompt_tokens_details=enable_prompt_tokens_details,
|
||||
enable_force_include_usage=enable_force_include_usage,
|
||||
)
|
||||
self.stop_reason_map = {
|
||||
"stop": "end_turn",
|
||||
"length": "max_tokens",
|
||||
"tool_calls": "tool_use",
|
||||
}
|
||||
|
||||
def _convert_anthropic_to_openai_request(
|
||||
self, anthropic_request: AnthropicMessagesRequest
|
||||
) -> ChatCompletionRequest:
|
||||
"""Convert Anthropic message format to OpenAI format"""
|
||||
openai_messages = []
|
||||
|
||||
# Add system message if provided
|
||||
if anthropic_request.system:
|
||||
if isinstance(anthropic_request.system, str):
|
||||
openai_messages.append(
|
||||
{"role": "system", "content": anthropic_request.system}
|
||||
)
|
||||
else:
|
||||
system_prompt = ""
|
||||
for block in anthropic_request.system:
|
||||
if block.type == "text" and block.text:
|
||||
system_prompt += block.text
|
||||
openai_messages.append({"role": "system", "content": system_prompt})
|
||||
|
||||
for msg in anthropic_request.messages:
|
||||
openai_msg: dict[str, Any] = {"role": msg.role} # type: ignore
|
||||
if isinstance(msg.content, str):
|
||||
openai_msg["content"] = msg.content
|
||||
else:
|
||||
# Handle complex content blocks
|
||||
content_parts: list[dict[str, Any]] = []
|
||||
tool_calls: list[dict[str, Any]] = []
|
||||
|
||||
for block in msg.content:
|
||||
if block.type == "text" and block.text:
|
||||
content_parts.append({"type": "text", "text": block.text})
|
||||
elif block.type == "image" and block.source:
|
||||
content_parts.append(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": block.source.get("data", "")},
|
||||
}
|
||||
)
|
||||
elif block.type == "tool_use":
|
||||
# Convert tool use to function call format
|
||||
tool_call = {
|
||||
"id": block.id or f"call_{int(time.time())}",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": block.name or "",
|
||||
"arguments": json.dumps(block.input or {}),
|
||||
},
|
||||
}
|
||||
tool_calls.append(tool_call)
|
||||
elif block.type == "tool_result":
|
||||
if msg.role == "user":
|
||||
openai_messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": block.id or "",
|
||||
"content": str(block.content)
|
||||
if block.content
|
||||
else "",
|
||||
}
|
||||
)
|
||||
else:
|
||||
# Assistant tool result becomes regular text
|
||||
tool_result_text = (
|
||||
str(block.content) if block.content else ""
|
||||
)
|
||||
content_parts.append(
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"Tool result: {tool_result_text}",
|
||||
}
|
||||
)
|
||||
|
||||
# Add tool calls to the message if any
|
||||
if tool_calls:
|
||||
openai_msg["tool_calls"] = tool_calls # type: ignore
|
||||
|
||||
# Add content parts if any
|
||||
if content_parts:
|
||||
if len(content_parts) == 1 and content_parts[0]["type"] == "text":
|
||||
openai_msg["content"] = content_parts[0]["text"]
|
||||
else:
|
||||
openai_msg["content"] = content_parts # type: ignore
|
||||
elif not tool_calls:
|
||||
continue
|
||||
|
||||
openai_messages.append(openai_msg)
|
||||
|
||||
req = ChatCompletionRequest(
|
||||
model=anthropic_request.model,
|
||||
messages=openai_messages,
|
||||
max_tokens=anthropic_request.max_tokens,
|
||||
max_completion_tokens=anthropic_request.max_tokens,
|
||||
stop=anthropic_request.stop_sequences,
|
||||
temperature=anthropic_request.temperature,
|
||||
top_p=anthropic_request.top_p,
|
||||
top_k=anthropic_request.top_k,
|
||||
)
|
||||
|
||||
if anthropic_request.stream:
|
||||
req.stream = anthropic_request.stream
|
||||
req.stream_options = StreamOptions.validate({"include_usage": True})
|
||||
|
||||
if anthropic_request.tool_choice is None:
|
||||
req.tool_choice = None
|
||||
elif anthropic_request.tool_choice.type == "auto":
|
||||
req.tool_choice = "auto"
|
||||
elif anthropic_request.tool_choice.type == "any":
|
||||
req.tool_choice = "required"
|
||||
elif anthropic_request.tool_choice.type == "tool":
|
||||
req.tool_choice = ChatCompletionNamedToolChoiceParam.model_validate(
|
||||
{
|
||||
"type": "function",
|
||||
"function": {"name": anthropic_request.tool_choice.name},
|
||||
}
|
||||
)
|
||||
|
||||
tools = []
|
||||
if anthropic_request.tools is None:
|
||||
return req
|
||||
for tool in anthropic_request.tools:
|
||||
tools.append(
|
||||
ChatCompletionToolsParam.model_validate(
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"parameters": tool.input_schema,
|
||||
},
|
||||
}
|
||||
)
|
||||
)
|
||||
if req.tool_choice is None:
|
||||
req.tool_choice = "auto"
|
||||
req.tools = tools
|
||||
return req
|
||||
|
||||
async def create_messages(
|
||||
self,
|
||||
request: AnthropicMessagesRequest,
|
||||
raw_request: Request | None = None,
|
||||
) -> AsyncGenerator[str, None] | AnthropicMessagesResponse | ErrorResponse:
|
||||
"""
|
||||
Messages API similar to Anthropic's API.
|
||||
|
||||
See https://docs.anthropic.com/en/api/messages
|
||||
for the API specification. This API mimics the Anthropic messages API.
|
||||
"""
|
||||
if logger.isEnabledFor(logging.DEBUG):
|
||||
logger.debug("Received messages request %s", request.model_dump_json())
|
||||
chat_req = self._convert_anthropic_to_openai_request(request)
|
||||
if logger.isEnabledFor(logging.DEBUG):
|
||||
logger.debug("Convert to OpenAI request %s", chat_req.model_dump_json())
|
||||
generator = await self.create_chat_completion(chat_req, raw_request)
|
||||
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return generator
|
||||
|
||||
elif isinstance(generator, ChatCompletionResponse):
|
||||
return self.messages_full_converter(generator)
|
||||
|
||||
return self.message_stream_converter(generator)
|
||||
|
||||
def messages_full_converter(
|
||||
self,
|
||||
generator: ChatCompletionResponse,
|
||||
) -> AnthropicMessagesResponse:
|
||||
result = AnthropicMessagesResponse(
|
||||
id=generator.id,
|
||||
content=[],
|
||||
model=generator.model,
|
||||
usage=AnthropicUsage(
|
||||
input_tokens=generator.usage.prompt_tokens,
|
||||
output_tokens=generator.usage.completion_tokens,
|
||||
),
|
||||
)
|
||||
if generator.choices[0].finish_reason == "stop":
|
||||
result.stop_reason = "end_turn"
|
||||
elif generator.choices[0].finish_reason == "length":
|
||||
result.stop_reason = "max_tokens"
|
||||
elif generator.choices[0].finish_reason == "tool_calls":
|
||||
result.stop_reason = "tool_use"
|
||||
|
||||
content: list[AnthropicContentBlock] = [
|
||||
AnthropicContentBlock(
|
||||
type="text",
|
||||
text=generator.choices[0].message.content
|
||||
if generator.choices[0].message.content
|
||||
else "",
|
||||
)
|
||||
]
|
||||
|
||||
for tool_call in generator.choices[0].message.tool_calls:
|
||||
anthropic_tool_call = AnthropicContentBlock(
|
||||
type="tool_use",
|
||||
id=tool_call.id,
|
||||
name=tool_call.function.name,
|
||||
input=json.loads(tool_call.function.arguments),
|
||||
)
|
||||
content += [anthropic_tool_call]
|
||||
|
||||
result.content = content
|
||||
|
||||
return result
|
||||
|
||||
async def message_stream_converter(
|
||||
self,
|
||||
generator: AsyncGenerator[str, None],
|
||||
) -> AsyncGenerator[str, None]:
|
||||
try:
|
||||
first_item = True
|
||||
finish_reason = None
|
||||
content_block_index = 0
|
||||
content_block_started = False
|
||||
|
||||
async for item in generator:
|
||||
if item.startswith("data:"):
|
||||
data_str = item[5:].strip().rstrip("\n")
|
||||
if data_str == "[DONE]":
|
||||
stop_message = AnthropicStreamEvent(
|
||||
type="message_stop",
|
||||
)
|
||||
data = stop_message.model_dump_json(
|
||||
exclude_unset=True, exclude_none=True
|
||||
)
|
||||
yield wrap_data_with_event(data, "message_stop")
|
||||
yield "data: [DONE]\n\n"
|
||||
else:
|
||||
origin_chunk = ChatCompletionStreamResponse.model_validate_json(
|
||||
data_str
|
||||
)
|
||||
|
||||
if first_item:
|
||||
chunk = AnthropicStreamEvent(
|
||||
type="message_start",
|
||||
message=AnthropicMessagesResponse(
|
||||
id=origin_chunk.id,
|
||||
content=[],
|
||||
model=origin_chunk.model,
|
||||
),
|
||||
)
|
||||
first_item = False
|
||||
data = chunk.model_dump_json(exclude_unset=True)
|
||||
yield wrap_data_with_event(data, "message_start")
|
||||
continue
|
||||
|
||||
# last chunk including usage info
|
||||
if len(origin_chunk.choices) == 0:
|
||||
if content_block_started:
|
||||
stop_chunk = AnthropicStreamEvent(
|
||||
index=content_block_index,
|
||||
type="content_block_stop",
|
||||
)
|
||||
data = stop_chunk.model_dump_json(exclude_unset=True)
|
||||
yield wrap_data_with_event(data, "content_block_stop")
|
||||
stop_reason = self.stop_reason_map.get(
|
||||
finish_reason or "stop"
|
||||
)
|
||||
chunk = AnthropicStreamEvent(
|
||||
type="message_delta",
|
||||
delta=AnthropicDelta(stop_reason=stop_reason),
|
||||
usage=AnthropicUsage(
|
||||
input_tokens=origin_chunk.usage.prompt_tokens
|
||||
if origin_chunk.usage
|
||||
else 0,
|
||||
output_tokens=origin_chunk.usage.completion_tokens
|
||||
if origin_chunk.usage
|
||||
else 0,
|
||||
),
|
||||
)
|
||||
data = chunk.model_dump_json(exclude_unset=True)
|
||||
yield wrap_data_with_event(data, "message_delta")
|
||||
continue
|
||||
|
||||
if origin_chunk.choices[0].finish_reason is not None:
|
||||
finish_reason = origin_chunk.choices[0].finish_reason
|
||||
continue
|
||||
|
||||
# content
|
||||
if origin_chunk.choices[0].delta.content is not None:
|
||||
if not content_block_started:
|
||||
chunk = AnthropicStreamEvent(
|
||||
index=content_block_index,
|
||||
type="content_block_start",
|
||||
content_block=AnthropicContentBlock(
|
||||
type="text", text=""
|
||||
),
|
||||
)
|
||||
data = chunk.model_dump_json(exclude_unset=True)
|
||||
yield wrap_data_with_event(data, "content_block_start")
|
||||
content_block_started = True
|
||||
|
||||
if origin_chunk.choices[0].delta.content == "":
|
||||
continue
|
||||
chunk = AnthropicStreamEvent(
|
||||
index=content_block_index,
|
||||
type="content_block_delta",
|
||||
delta=AnthropicDelta(
|
||||
type="text_delta",
|
||||
text=origin_chunk.choices[0].delta.content,
|
||||
),
|
||||
)
|
||||
data = chunk.model_dump_json(exclude_unset=True)
|
||||
yield wrap_data_with_event(data, "content_block_delta")
|
||||
continue
|
||||
|
||||
# tool calls
|
||||
elif len(origin_chunk.choices[0].delta.tool_calls) > 0:
|
||||
tool_call = origin_chunk.choices[0].delta.tool_calls[0]
|
||||
if tool_call.id is not None:
|
||||
if content_block_started:
|
||||
stop_chunk = AnthropicStreamEvent(
|
||||
index=content_block_index,
|
||||
type="content_block_stop",
|
||||
)
|
||||
data = stop_chunk.model_dump_json(
|
||||
exclude_unset=True
|
||||
)
|
||||
yield wrap_data_with_event(
|
||||
data, "content_block_stop"
|
||||
)
|
||||
content_block_started = False
|
||||
content_block_index += 1
|
||||
|
||||
chunk = AnthropicStreamEvent(
|
||||
index=content_block_index,
|
||||
type="content_block_start",
|
||||
content_block=AnthropicContentBlock(
|
||||
type="tool_use",
|
||||
id=tool_call.id,
|
||||
name=tool_call.function.name
|
||||
if tool_call.function
|
||||
else None,
|
||||
input={},
|
||||
),
|
||||
)
|
||||
data = chunk.model_dump_json(exclude_unset=True)
|
||||
yield wrap_data_with_event(data, "content_block_start")
|
||||
content_block_started = True
|
||||
|
||||
else:
|
||||
chunk = AnthropicStreamEvent(
|
||||
index=content_block_index,
|
||||
type="content_block_delta",
|
||||
delta=AnthropicDelta(
|
||||
type="input_json_delta",
|
||||
partial_json=tool_call.function.arguments
|
||||
if tool_call.function
|
||||
else None,
|
||||
),
|
||||
)
|
||||
data = chunk.model_dump_json(exclude_unset=True)
|
||||
yield wrap_data_with_event(data, "content_block_delta")
|
||||
continue
|
||||
else:
|
||||
error_response = AnthropicStreamEvent(
|
||||
type="error",
|
||||
error=AnthropicError(
|
||||
type="internal_error",
|
||||
message="Invalid data format received",
|
||||
),
|
||||
)
|
||||
data = error_response.model_dump_json(exclude_unset=True)
|
||||
yield wrap_data_with_event(data, "error")
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Error in message stream converter.")
|
||||
error_response = AnthropicStreamEvent(
|
||||
type="error",
|
||||
error=AnthropicError(type="internal_error", message=str(e)),
|
||||
)
|
||||
data = error_response.model_dump_json(exclude_unset=True)
|
||||
yield wrap_data_with_event(data, "error")
|
||||
yield "data: [DONE]\n\n"
|
||||
184
vllm_old/entrypoints/api_server.py
Normal file
184
vllm_old/entrypoints/api_server.py
Normal file
@@ -0,0 +1,184 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
NOTE: This API server is used only for demonstrating usage of AsyncEngine
|
||||
and simple performance benchmarks. It is not intended for production use.
|
||||
For production use, we recommend using our OpenAI compatible server.
|
||||
We are also not going to accept PRs modifying this file, please
|
||||
change `vllm/entrypoints/openai/api_server.py` instead.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import ssl
|
||||
from argparse import Namespace
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import JSONResponse, Response, StreamingResponse
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
from vllm.entrypoints.launcher import serve_http
|
||||
from vllm.entrypoints.utils import with_cancellation
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils import random_uuid
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
from vllm.utils.system_utils import set_ulimit
|
||||
from vllm.version import __version__ as VLLM_VERSION
|
||||
|
||||
logger = init_logger("vllm.entrypoints.api_server")
|
||||
|
||||
app = FastAPI()
|
||||
engine = None
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health() -> Response:
|
||||
"""Health check."""
|
||||
return Response(status_code=200)
|
||||
|
||||
|
||||
@app.post("/generate")
|
||||
async def generate(request: Request) -> Response:
|
||||
"""Generate completion for the request.
|
||||
|
||||
The request should be a JSON object with the following fields:
|
||||
- prompt: the prompt to use for the generation.
|
||||
- stream: whether to stream the results or not.
|
||||
- other fields: the sampling parameters (See `SamplingParams` for details).
|
||||
"""
|
||||
request_dict = await request.json()
|
||||
return await _generate(request_dict, raw_request=request)
|
||||
|
||||
|
||||
@with_cancellation
|
||||
async def _generate(request_dict: dict, raw_request: Request) -> Response:
|
||||
prompt = request_dict.pop("prompt")
|
||||
stream = request_dict.pop("stream", False)
|
||||
sampling_params = SamplingParams(**request_dict)
|
||||
request_id = random_uuid()
|
||||
|
||||
assert engine is not None
|
||||
results_generator = engine.generate(prompt, sampling_params, request_id)
|
||||
|
||||
# Streaming case
|
||||
async def stream_results() -> AsyncGenerator[bytes, None]:
|
||||
async for request_output in results_generator:
|
||||
prompt = request_output.prompt
|
||||
assert prompt is not None
|
||||
text_outputs = [prompt + output.text for output in request_output.outputs]
|
||||
ret = {"text": text_outputs}
|
||||
yield (json.dumps(ret) + "\n").encode("utf-8")
|
||||
|
||||
if stream:
|
||||
return StreamingResponse(stream_results())
|
||||
|
||||
# Non-streaming case
|
||||
final_output = None
|
||||
try:
|
||||
async for request_output in results_generator:
|
||||
final_output = request_output
|
||||
except asyncio.CancelledError:
|
||||
return Response(status_code=499)
|
||||
|
||||
assert final_output is not None
|
||||
prompt = final_output.prompt
|
||||
assert prompt is not None
|
||||
text_outputs = [prompt + output.text for output in final_output.outputs]
|
||||
ret = {"text": text_outputs}
|
||||
return JSONResponse(ret)
|
||||
|
||||
|
||||
def build_app(args: Namespace) -> FastAPI:
|
||||
global app
|
||||
|
||||
app.root_path = args.root_path
|
||||
return app
|
||||
|
||||
|
||||
async def init_app(
|
||||
args: Namespace,
|
||||
llm_engine: AsyncLLMEngine | None = None,
|
||||
) -> FastAPI:
|
||||
app = build_app(args)
|
||||
|
||||
global engine
|
||||
|
||||
engine_args = AsyncEngineArgs.from_cli_args(args)
|
||||
engine = (
|
||||
llm_engine
|
||||
if llm_engine is not None
|
||||
else AsyncLLMEngine.from_engine_args(
|
||||
engine_args, usage_context=UsageContext.API_SERVER
|
||||
)
|
||||
)
|
||||
app.state.engine_client = engine
|
||||
return app
|
||||
|
||||
|
||||
async def run_server(
|
||||
args: Namespace, llm_engine: AsyncLLMEngine | None = None, **uvicorn_kwargs: Any
|
||||
) -> None:
|
||||
logger.info("vLLM API server version %s", VLLM_VERSION)
|
||||
logger.info("args: %s", args)
|
||||
|
||||
set_ulimit()
|
||||
|
||||
app = await init_app(args, llm_engine)
|
||||
assert engine is not None
|
||||
|
||||
shutdown_task = await serve_http(
|
||||
app,
|
||||
sock=None,
|
||||
enable_ssl_refresh=args.enable_ssl_refresh,
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
log_level=args.log_level,
|
||||
timeout_keep_alive=envs.VLLM_HTTP_TIMEOUT_KEEP_ALIVE,
|
||||
ssl_keyfile=args.ssl_keyfile,
|
||||
ssl_certfile=args.ssl_certfile,
|
||||
ssl_ca_certs=args.ssl_ca_certs,
|
||||
ssl_cert_reqs=args.ssl_cert_reqs,
|
||||
**uvicorn_kwargs,
|
||||
)
|
||||
|
||||
await shutdown_task
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = FlexibleArgumentParser()
|
||||
parser.add_argument("--host", type=str, default=None)
|
||||
parser.add_argument("--port", type=parser.check_port, default=8000)
|
||||
parser.add_argument("--ssl-keyfile", type=str, default=None)
|
||||
parser.add_argument("--ssl-certfile", type=str, default=None)
|
||||
parser.add_argument(
|
||||
"--ssl-ca-certs", type=str, default=None, help="The CA certificates file"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable-ssl-refresh",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Refresh SSL Context when SSL certificate files change",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ssl-cert-reqs",
|
||||
type=int,
|
||||
default=int(ssl.CERT_NONE),
|
||||
help="Whether client certificate is required (see stdlib ssl module's)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--root-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="FastAPI root_path when app is behind a path based routing proxy",
|
||||
)
|
||||
parser.add_argument("--log-level", type=str, default="debug")
|
||||
parser = AsyncEngineArgs.add_cli_args(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
asyncio.run(run_server(args))
|
||||
1690
vllm_old/entrypoints/chat_utils.py
Normal file
1690
vllm_old/entrypoints/chat_utils.py
Normal file
File diff suppressed because it is too large
Load Diff
13
vllm_old/entrypoints/cli/__init__.py
Normal file
13
vllm_old/entrypoints/cli/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from vllm.entrypoints.cli.benchmark.latency import BenchmarkLatencySubcommand
|
||||
from vllm.entrypoints.cli.benchmark.serve import BenchmarkServingSubcommand
|
||||
from vllm.entrypoints.cli.benchmark.sweep import BenchmarkSweepSubcommand
|
||||
from vllm.entrypoints.cli.benchmark.throughput import BenchmarkThroughputSubcommand
|
||||
|
||||
__all__: list[str] = [
|
||||
"BenchmarkLatencySubcommand",
|
||||
"BenchmarkServingSubcommand",
|
||||
"BenchmarkSweepSubcommand",
|
||||
"BenchmarkThroughputSubcommand",
|
||||
]
|
||||
BIN
vllm_old/entrypoints/cli/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
vllm_old/entrypoints/cli/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm_old/entrypoints/cli/__pycache__/collect_env.cpython-312.pyc
Normal file
BIN
vllm_old/entrypoints/cli/__pycache__/collect_env.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm_old/entrypoints/cli/__pycache__/main.cpython-312.pyc
Normal file
BIN
vllm_old/entrypoints/cli/__pycache__/main.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm_old/entrypoints/cli/__pycache__/openai.cpython-312.pyc
Normal file
BIN
vllm_old/entrypoints/cli/__pycache__/openai.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm_old/entrypoints/cli/__pycache__/run_batch.cpython-312.pyc
Normal file
BIN
vllm_old/entrypoints/cli/__pycache__/run_batch.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm_old/entrypoints/cli/__pycache__/serve.cpython-312.pyc
Normal file
BIN
vllm_old/entrypoints/cli/__pycache__/serve.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm_old/entrypoints/cli/__pycache__/types.cpython-312.pyc
Normal file
BIN
vllm_old/entrypoints/cli/__pycache__/types.cpython-312.pyc
Normal file
Binary file not shown.
0
vllm_old/entrypoints/cli/benchmark/__init__.py
Normal file
0
vllm_old/entrypoints/cli/benchmark/__init__.py
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
25
vllm_old/entrypoints/cli/benchmark/base.py
Normal file
25
vllm_old/entrypoints/cli/benchmark/base.py
Normal file
@@ -0,0 +1,25 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import argparse
|
||||
|
||||
from vllm.entrypoints.cli.types import CLISubcommand
|
||||
|
||||
|
||||
class BenchmarkSubcommandBase(CLISubcommand):
|
||||
"""The base class of subcommands for `vllm bench`."""
|
||||
|
||||
help: str
|
||||
|
||||
@classmethod
|
||||
def add_cli_args(cls, parser: argparse.ArgumentParser) -> None:
|
||||
"""Add the CLI arguments to the parser."""
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def cmd(args: argparse.Namespace) -> None:
|
||||
"""Run the benchmark.
|
||||
|
||||
Args:
|
||||
args: The arguments to the command.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
21
vllm_old/entrypoints/cli/benchmark/latency.py
Normal file
21
vllm_old/entrypoints/cli/benchmark/latency.py
Normal file
@@ -0,0 +1,21 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import argparse
|
||||
|
||||
from vllm.benchmarks.latency import add_cli_args, main
|
||||
from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase
|
||||
|
||||
|
||||
class BenchmarkLatencySubcommand(BenchmarkSubcommandBase):
|
||||
"""The `latency` subcommand for `vllm bench`."""
|
||||
|
||||
name = "latency"
|
||||
help = "Benchmark the latency of a single batch of requests."
|
||||
|
||||
@classmethod
|
||||
def add_cli_args(cls, parser: argparse.ArgumentParser) -> None:
|
||||
add_cli_args(parser)
|
||||
|
||||
@staticmethod
|
||||
def cmd(args: argparse.Namespace) -> None:
|
||||
main(args)
|
||||
56
vllm_old/entrypoints/cli/benchmark/main.py
Normal file
56
vllm_old/entrypoints/cli/benchmark/main.py
Normal file
@@ -0,0 +1,56 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import argparse
|
||||
import typing
|
||||
|
||||
from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase
|
||||
from vllm.entrypoints.cli.types import CLISubcommand
|
||||
from vllm.entrypoints.utils import VLLM_SUBCMD_PARSER_EPILOG
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
else:
|
||||
FlexibleArgumentParser = argparse.ArgumentParser
|
||||
|
||||
|
||||
class BenchmarkSubcommand(CLISubcommand):
|
||||
"""The `bench` subcommand for the vLLM CLI."""
|
||||
|
||||
name = "bench"
|
||||
help = "vLLM bench subcommand."
|
||||
|
||||
@staticmethod
|
||||
def cmd(args: argparse.Namespace) -> None:
|
||||
args.dispatch_function(args)
|
||||
|
||||
def validate(self, args: argparse.Namespace) -> None:
|
||||
pass
|
||||
|
||||
def subparser_init(
|
||||
self, subparsers: argparse._SubParsersAction
|
||||
) -> FlexibleArgumentParser:
|
||||
bench_parser = subparsers.add_parser(
|
||||
self.name,
|
||||
description=self.help,
|
||||
usage=f"vllm {self.name} <bench_type> [options]",
|
||||
)
|
||||
bench_subparsers = bench_parser.add_subparsers(required=True, dest="bench_type")
|
||||
|
||||
for cmd_cls in BenchmarkSubcommandBase.__subclasses__():
|
||||
cmd_subparser = bench_subparsers.add_parser(
|
||||
cmd_cls.name,
|
||||
help=cmd_cls.help,
|
||||
description=cmd_cls.help,
|
||||
usage=f"vllm {self.name} {cmd_cls.name} [options]",
|
||||
)
|
||||
cmd_subparser.set_defaults(dispatch_function=cmd_cls.cmd)
|
||||
cmd_cls.add_cli_args(cmd_subparser)
|
||||
cmd_subparser.epilog = VLLM_SUBCMD_PARSER_EPILOG.format(
|
||||
subcmd=f"{self.name} {cmd_cls.name}"
|
||||
)
|
||||
return bench_parser
|
||||
|
||||
|
||||
def cmd_init() -> list[CLISubcommand]:
|
||||
return [BenchmarkSubcommand()]
|
||||
21
vllm_old/entrypoints/cli/benchmark/serve.py
Normal file
21
vllm_old/entrypoints/cli/benchmark/serve.py
Normal file
@@ -0,0 +1,21 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import argparse
|
||||
|
||||
from vllm.benchmarks.serve import add_cli_args, main
|
||||
from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase
|
||||
|
||||
|
||||
class BenchmarkServingSubcommand(BenchmarkSubcommandBase):
|
||||
"""The `serve` subcommand for `vllm bench`."""
|
||||
|
||||
name = "serve"
|
||||
help = "Benchmark the online serving throughput."
|
||||
|
||||
@classmethod
|
||||
def add_cli_args(cls, parser: argparse.ArgumentParser) -> None:
|
||||
add_cli_args(parser)
|
||||
|
||||
@staticmethod
|
||||
def cmd(args: argparse.Namespace) -> None:
|
||||
main(args)
|
||||
21
vllm_old/entrypoints/cli/benchmark/sweep.py
Normal file
21
vllm_old/entrypoints/cli/benchmark/sweep.py
Normal file
@@ -0,0 +1,21 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import argparse
|
||||
|
||||
from vllm.benchmarks.sweep.cli import add_cli_args, main
|
||||
from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase
|
||||
|
||||
|
||||
class BenchmarkSweepSubcommand(BenchmarkSubcommandBase):
|
||||
"""The `sweep` subcommand for `vllm bench`."""
|
||||
|
||||
name = "sweep"
|
||||
help = "Benchmark for a parameter sweep."
|
||||
|
||||
@classmethod
|
||||
def add_cli_args(cls, parser: argparse.ArgumentParser) -> None:
|
||||
add_cli_args(parser)
|
||||
|
||||
@staticmethod
|
||||
def cmd(args: argparse.Namespace) -> None:
|
||||
main(args)
|
||||
21
vllm_old/entrypoints/cli/benchmark/throughput.py
Normal file
21
vllm_old/entrypoints/cli/benchmark/throughput.py
Normal file
@@ -0,0 +1,21 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import argparse
|
||||
|
||||
from vllm.benchmarks.throughput import add_cli_args, main
|
||||
from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase
|
||||
|
||||
|
||||
class BenchmarkThroughputSubcommand(BenchmarkSubcommandBase):
|
||||
"""The `throughput` subcommand for `vllm bench`."""
|
||||
|
||||
name = "throughput"
|
||||
help = "Benchmark offline inference throughput."
|
||||
|
||||
@classmethod
|
||||
def add_cli_args(cls, parser: argparse.ArgumentParser) -> None:
|
||||
add_cli_args(parser)
|
||||
|
||||
@staticmethod
|
||||
def cmd(args: argparse.Namespace) -> None:
|
||||
main(args)
|
||||
38
vllm_old/entrypoints/cli/collect_env.py
Normal file
38
vllm_old/entrypoints/cli/collect_env.py
Normal file
@@ -0,0 +1,38 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import argparse
|
||||
import typing
|
||||
|
||||
from vllm.collect_env import main as collect_env_main
|
||||
from vllm.entrypoints.cli.types import CLISubcommand
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
else:
|
||||
FlexibleArgumentParser = argparse.ArgumentParser
|
||||
|
||||
|
||||
class CollectEnvSubcommand(CLISubcommand):
|
||||
"""The `collect-env` subcommand for the vLLM CLI."""
|
||||
|
||||
name = "collect-env"
|
||||
|
||||
@staticmethod
|
||||
def cmd(args: argparse.Namespace) -> None:
|
||||
"""Collect information about the environment."""
|
||||
collect_env_main()
|
||||
|
||||
def subparser_init(
|
||||
self, subparsers: argparse._SubParsersAction
|
||||
) -> FlexibleArgumentParser:
|
||||
return subparsers.add_parser(
|
||||
"collect-env",
|
||||
help="Start collecting environment information.",
|
||||
description="Start collecting environment information.",
|
||||
usage="vllm collect-env",
|
||||
)
|
||||
|
||||
|
||||
def cmd_init() -> list[CLISubcommand]:
|
||||
return [CollectEnvSubcommand()]
|
||||
79
vllm_old/entrypoints/cli/main.py
Normal file
79
vllm_old/entrypoints/cli/main.py
Normal file
@@ -0,0 +1,79 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""The CLI entrypoints of vLLM
|
||||
|
||||
Note that all future modules must be lazily loaded within main
|
||||
to avoid certain eager import breakage."""
|
||||
|
||||
import importlib.metadata
|
||||
import sys
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def main():
|
||||
import vllm.entrypoints.cli.benchmark.main
|
||||
import vllm.entrypoints.cli.collect_env
|
||||
import vllm.entrypoints.cli.openai
|
||||
import vllm.entrypoints.cli.run_batch
|
||||
import vllm.entrypoints.cli.serve
|
||||
from vllm.entrypoints.utils import VLLM_SUBCMD_PARSER_EPILOG, cli_env_setup
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
|
||||
CMD_MODULES = [
|
||||
vllm.entrypoints.cli.openai,
|
||||
vllm.entrypoints.cli.serve,
|
||||
vllm.entrypoints.cli.benchmark.main,
|
||||
vllm.entrypoints.cli.collect_env,
|
||||
vllm.entrypoints.cli.run_batch,
|
||||
]
|
||||
|
||||
cli_env_setup()
|
||||
|
||||
# For 'vllm bench *': use CPU instead of UnspecifiedPlatform by default
|
||||
if len(sys.argv) > 1 and sys.argv[1] == "bench":
|
||||
logger.debug(
|
||||
"Bench command detected, must ensure current platform is not "
|
||||
"UnspecifiedPlatform to avoid device type inference error"
|
||||
)
|
||||
from vllm import platforms
|
||||
|
||||
if platforms.current_platform.is_unspecified():
|
||||
from vllm.platforms.cpu import CpuPlatform
|
||||
|
||||
platforms.current_platform = CpuPlatform()
|
||||
logger.info(
|
||||
"Unspecified platform detected, switching to CPU Platform instead."
|
||||
)
|
||||
|
||||
parser = FlexibleArgumentParser(
|
||||
description="vLLM CLI",
|
||||
epilog=VLLM_SUBCMD_PARSER_EPILOG.format(subcmd="[subcommand]"),
|
||||
)
|
||||
parser.add_argument(
|
||||
"-v",
|
||||
"--version",
|
||||
action="version",
|
||||
version=importlib.metadata.version("vllm"),
|
||||
)
|
||||
subparsers = parser.add_subparsers(required=False, dest="subparser")
|
||||
cmds = {}
|
||||
for cmd_module in CMD_MODULES:
|
||||
new_cmds = cmd_module.cmd_init()
|
||||
for cmd in new_cmds:
|
||||
cmd.subparser_init(subparsers).set_defaults(dispatch_function=cmd.cmd)
|
||||
cmds[cmd.name] = cmd
|
||||
args = parser.parse_args()
|
||||
if args.subparser in cmds:
|
||||
cmds[args.subparser].validate(args)
|
||||
|
||||
if hasattr(args, "dispatch_function"):
|
||||
args.dispatch_function(args)
|
||||
else:
|
||||
parser.print_help()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
256
vllm_old/entrypoints/cli/openai.py
Normal file
256
vllm_old/entrypoints/cli/openai.py
Normal file
@@ -0,0 +1,256 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from openai import OpenAI
|
||||
from openai.types.chat import ChatCompletionMessageParam
|
||||
|
||||
from vllm.entrypoints.cli.types import CLISubcommand
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
else:
|
||||
FlexibleArgumentParser = argparse.ArgumentParser
|
||||
|
||||
|
||||
def _register_signal_handlers():
|
||||
def signal_handler(sig, frame):
|
||||
sys.exit(0)
|
||||
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
signal.signal(signal.SIGTSTP, signal_handler)
|
||||
|
||||
|
||||
def _interactive_cli(args: argparse.Namespace) -> tuple[str, OpenAI]:
|
||||
_register_signal_handlers()
|
||||
|
||||
base_url = args.url
|
||||
api_key = args.api_key or os.environ.get("OPENAI_API_KEY", "EMPTY")
|
||||
openai_client = OpenAI(api_key=api_key, base_url=base_url)
|
||||
|
||||
if args.model_name:
|
||||
model_name = args.model_name
|
||||
else:
|
||||
available_models = openai_client.models.list()
|
||||
model_name = available_models.data[0].id
|
||||
|
||||
print(f"Using model: {model_name}")
|
||||
|
||||
return model_name, openai_client
|
||||
|
||||
|
||||
def _print_chat_stream(stream) -> str:
|
||||
output = ""
|
||||
for chunk in stream:
|
||||
delta = chunk.choices[0].delta
|
||||
if delta.content:
|
||||
output += delta.content
|
||||
print(delta.content, end="", flush=True)
|
||||
print()
|
||||
return output
|
||||
|
||||
|
||||
def _print_completion_stream(stream) -> str:
|
||||
output = ""
|
||||
for chunk in stream:
|
||||
text = chunk.choices[0].text
|
||||
if text is not None:
|
||||
output += text
|
||||
print(text, end="", flush=True)
|
||||
print()
|
||||
return output
|
||||
|
||||
|
||||
def chat(system_prompt: str | None, model_name: str, client: OpenAI) -> None:
|
||||
conversation: list[ChatCompletionMessageParam] = []
|
||||
if system_prompt is not None:
|
||||
conversation.append({"role": "system", "content": system_prompt})
|
||||
|
||||
print("Please enter a message for the chat model:")
|
||||
while True:
|
||||
try:
|
||||
input_message = input("> ")
|
||||
except EOFError:
|
||||
break
|
||||
conversation.append({"role": "user", "content": input_message})
|
||||
|
||||
stream = client.chat.completions.create(
|
||||
model=model_name, messages=conversation, stream=True
|
||||
)
|
||||
output = _print_chat_stream(stream)
|
||||
conversation.append({"role": "assistant", "content": output})
|
||||
|
||||
|
||||
def _add_query_options(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
|
||||
parser.add_argument(
|
||||
"--url",
|
||||
type=str,
|
||||
default="http://localhost:8000/v1",
|
||||
help="url of the running OpenAI-Compatible RESTful API server",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model-name",
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
"The model name used in prompt completion, default to "
|
||||
"the first model in list models API call."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--api-key",
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
"API key for OpenAI services. If provided, this api key "
|
||||
"will overwrite the api key obtained through environment variables."
|
||||
),
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
class ChatCommand(CLISubcommand):
|
||||
"""The `chat` subcommand for the vLLM CLI."""
|
||||
|
||||
name = "chat"
|
||||
|
||||
@staticmethod
|
||||
def cmd(args: argparse.Namespace) -> None:
|
||||
model_name, client = _interactive_cli(args)
|
||||
system_prompt = args.system_prompt
|
||||
conversation: list[ChatCompletionMessageParam] = []
|
||||
|
||||
if system_prompt is not None:
|
||||
conversation.append({"role": "system", "content": system_prompt})
|
||||
|
||||
if args.quick:
|
||||
conversation.append({"role": "user", "content": args.quick})
|
||||
|
||||
stream = client.chat.completions.create(
|
||||
model=model_name, messages=conversation, stream=True
|
||||
)
|
||||
output = _print_chat_stream(stream)
|
||||
conversation.append({"role": "assistant", "content": output})
|
||||
return
|
||||
|
||||
print("Please enter a message for the chat model:")
|
||||
while True:
|
||||
try:
|
||||
input_message = input("> ")
|
||||
except EOFError:
|
||||
break
|
||||
conversation.append({"role": "user", "content": input_message})
|
||||
|
||||
stream = client.chat.completions.create(
|
||||
model=model_name, messages=conversation, stream=True
|
||||
)
|
||||
output = _print_chat_stream(stream)
|
||||
conversation.append({"role": "assistant", "content": output})
|
||||
|
||||
@staticmethod
|
||||
def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
|
||||
"""Add CLI arguments for the chat command."""
|
||||
_add_query_options(parser)
|
||||
parser.add_argument(
|
||||
"--system-prompt",
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
"The system prompt to be added to the chat template, "
|
||||
"used for models that support system prompts."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"-q",
|
||||
"--quick",
|
||||
type=str,
|
||||
metavar="MESSAGE",
|
||||
help=("Send a single prompt as MESSAGE and print the response, then exit."),
|
||||
)
|
||||
return parser
|
||||
|
||||
def subparser_init(
|
||||
self, subparsers: argparse._SubParsersAction
|
||||
) -> FlexibleArgumentParser:
|
||||
parser = subparsers.add_parser(
|
||||
"chat",
|
||||
help="Generate chat completions via the running API server.",
|
||||
description="Generate chat completions via the running API server.",
|
||||
usage="vllm chat [options]",
|
||||
)
|
||||
return ChatCommand.add_cli_args(parser)
|
||||
|
||||
|
||||
class CompleteCommand(CLISubcommand):
|
||||
"""The `complete` subcommand for the vLLM CLI."""
|
||||
|
||||
name = "complete"
|
||||
|
||||
@staticmethod
|
||||
def cmd(args: argparse.Namespace) -> None:
|
||||
model_name, client = _interactive_cli(args)
|
||||
|
||||
kwargs = {
|
||||
"model": model_name,
|
||||
"stream": True,
|
||||
}
|
||||
if args.max_tokens:
|
||||
kwargs["max_tokens"] = args.max_tokens
|
||||
|
||||
if args.quick:
|
||||
stream = client.completions.create(prompt=args.quick, **kwargs)
|
||||
_print_completion_stream(stream)
|
||||
return
|
||||
|
||||
print("Please enter prompt to complete:")
|
||||
while True:
|
||||
try:
|
||||
input_prompt = input("> ")
|
||||
except EOFError:
|
||||
break
|
||||
stream = client.completions.create(prompt=input_prompt, **kwargs)
|
||||
_print_completion_stream(stream)
|
||||
|
||||
@staticmethod
|
||||
def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
|
||||
"""Add CLI arguments for the complete command."""
|
||||
_add_query_options(parser)
|
||||
parser.add_argument(
|
||||
"--max-tokens",
|
||||
type=int,
|
||||
help="Maximum number of tokens to generate per output sequence.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-q",
|
||||
"--quick",
|
||||
type=str,
|
||||
metavar="PROMPT",
|
||||
help="Send a single prompt and print the completion output, then exit.",
|
||||
)
|
||||
return parser
|
||||
|
||||
def subparser_init(
|
||||
self, subparsers: argparse._SubParsersAction
|
||||
) -> FlexibleArgumentParser:
|
||||
parser = subparsers.add_parser(
|
||||
"complete",
|
||||
help=(
|
||||
"Generate text completions based on the given prompt "
|
||||
"via the running API server."
|
||||
),
|
||||
description=(
|
||||
"Generate text completions based on the given prompt "
|
||||
"via the running API server."
|
||||
),
|
||||
usage="vllm complete [options]",
|
||||
)
|
||||
return CompleteCommand.add_cli_args(parser)
|
||||
|
||||
|
||||
def cmd_init() -> list[CLISubcommand]:
|
||||
return [ChatCommand(), CompleteCommand()]
|
||||
68
vllm_old/entrypoints/cli/run_batch.py
Normal file
68
vllm_old/entrypoints/cli/run_batch.py
Normal file
@@ -0,0 +1,68 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import importlib.metadata
|
||||
import typing
|
||||
|
||||
from vllm.entrypoints.cli.types import CLISubcommand
|
||||
from vllm.entrypoints.utils import VLLM_SUBCMD_PARSER_EPILOG
|
||||
from vllm.logger import init_logger
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
else:
|
||||
FlexibleArgumentParser = argparse.ArgumentParser
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class RunBatchSubcommand(CLISubcommand):
|
||||
"""The `run-batch` subcommand for vLLM CLI."""
|
||||
|
||||
name = "run-batch"
|
||||
|
||||
@staticmethod
|
||||
def cmd(args: argparse.Namespace) -> None:
|
||||
from vllm.entrypoints.openai.run_batch import main as run_batch_main
|
||||
|
||||
logger.info(
|
||||
"vLLM batch processing API version %s", importlib.metadata.version("vllm")
|
||||
)
|
||||
logger.info("args: %s", args)
|
||||
|
||||
# Start the Prometheus metrics server.
|
||||
# LLMEngine uses the Prometheus client
|
||||
# to publish metrics at the /metrics endpoint.
|
||||
if args.enable_metrics:
|
||||
from prometheus_client import start_http_server
|
||||
|
||||
logger.info("Prometheus metrics enabled")
|
||||
start_http_server(port=args.port, addr=args.url)
|
||||
else:
|
||||
logger.info("Prometheus metrics disabled")
|
||||
|
||||
asyncio.run(run_batch_main(args))
|
||||
|
||||
def subparser_init(
|
||||
self, subparsers: argparse._SubParsersAction
|
||||
) -> FlexibleArgumentParser:
|
||||
from vllm.entrypoints.openai.run_batch import make_arg_parser
|
||||
|
||||
run_batch_parser = subparsers.add_parser(
|
||||
self.name,
|
||||
help="Run batch prompts and write results to file.",
|
||||
description=(
|
||||
"Run batch prompts using vLLM's OpenAI-compatible API.\n"
|
||||
"Supports local or HTTP input/output files."
|
||||
),
|
||||
usage="vllm run-batch -i INPUT.jsonl -o OUTPUT.jsonl --model <model>",
|
||||
)
|
||||
run_batch_parser = make_arg_parser(run_batch_parser)
|
||||
run_batch_parser.epilog = VLLM_SUBCMD_PARSER_EPILOG.format(subcmd=self.name)
|
||||
return run_batch_parser
|
||||
|
||||
|
||||
def cmd_init() -> list[CLISubcommand]:
|
||||
return [RunBatchSubcommand()]
|
||||
249
vllm_old/entrypoints/cli/serve.py
Normal file
249
vllm_old/entrypoints/cli/serve.py
Normal file
@@ -0,0 +1,249 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import argparse
|
||||
import signal
|
||||
|
||||
import uvloop
|
||||
|
||||
import vllm
|
||||
import vllm.envs as envs
|
||||
from vllm.entrypoints.cli.types import CLISubcommand
|
||||
from vllm.entrypoints.openai.api_server import (
|
||||
run_server,
|
||||
run_server_worker,
|
||||
setup_server,
|
||||
)
|
||||
from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args
|
||||
from vllm.entrypoints.utils import VLLM_SUBCMD_PARSER_EPILOG
|
||||
from vllm.logger import init_logger
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
from vllm.utils.network_utils import get_tcp_uri
|
||||
from vllm.utils.system_utils import decorate_logs, set_process_title
|
||||
from vllm.v1.engine.core import EngineCoreProc
|
||||
from vllm.v1.engine.utils import CoreEngineProcManager, launch_core_engines
|
||||
from vllm.v1.executor import Executor
|
||||
from vllm.v1.executor.multiproc_executor import MultiprocExecutor
|
||||
from vllm.v1.metrics.prometheus import setup_multiprocess_prometheus
|
||||
from vllm.v1.utils import APIServerProcessManager, wait_for_completion_or_failure
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
DESCRIPTION = """Launch a local OpenAI-compatible API server to serve LLM
|
||||
completions via HTTP. Defaults to Qwen/Qwen3-0.6B if no model is specified.
|
||||
|
||||
Search by using: `--help=<ConfigGroup>` to explore options by section (e.g.,
|
||||
--help=ModelConfig, --help=Frontend)
|
||||
Use `--help=all` to show all available flags at once.
|
||||
"""
|
||||
|
||||
|
||||
class ServeSubcommand(CLISubcommand):
|
||||
"""The `serve` subcommand for the vLLM CLI."""
|
||||
|
||||
name = "serve"
|
||||
|
||||
@staticmethod
|
||||
def cmd(args: argparse.Namespace) -> None:
|
||||
# If model is specified in CLI (as positional arg), it takes precedence
|
||||
if hasattr(args, "model_tag") and args.model_tag is not None:
|
||||
args.model = args.model_tag
|
||||
|
||||
if args.headless or args.api_server_count < 1:
|
||||
run_headless(args)
|
||||
else:
|
||||
if args.api_server_count > 1:
|
||||
run_multi_api_server(args)
|
||||
else:
|
||||
# Single API server (this process).
|
||||
uvloop.run(run_server(args))
|
||||
|
||||
def validate(self, args: argparse.Namespace) -> None:
|
||||
validate_parsed_serve_args(args)
|
||||
|
||||
def subparser_init(
|
||||
self, subparsers: argparse._SubParsersAction
|
||||
) -> FlexibleArgumentParser:
|
||||
serve_parser = subparsers.add_parser(
|
||||
self.name, description=DESCRIPTION, usage="vllm serve [model_tag] [options]"
|
||||
)
|
||||
|
||||
serve_parser = make_arg_parser(serve_parser)
|
||||
serve_parser.epilog = VLLM_SUBCMD_PARSER_EPILOG.format(subcmd=self.name)
|
||||
return serve_parser
|
||||
|
||||
|
||||
def cmd_init() -> list[CLISubcommand]:
|
||||
return [ServeSubcommand()]
|
||||
|
||||
|
||||
def run_headless(args: argparse.Namespace):
|
||||
if args.api_server_count > 1:
|
||||
raise ValueError("api_server_count can't be set in headless mode")
|
||||
|
||||
# Create the EngineConfig.
|
||||
engine_args = vllm.AsyncEngineArgs.from_cli_args(args)
|
||||
usage_context = UsageContext.OPENAI_API_SERVER
|
||||
vllm_config = engine_args.create_engine_config(
|
||||
usage_context=usage_context, headless=True
|
||||
)
|
||||
|
||||
if engine_args.data_parallel_hybrid_lb:
|
||||
raise ValueError("data_parallel_hybrid_lb is not applicable in headless mode")
|
||||
|
||||
parallel_config = vllm_config.parallel_config
|
||||
local_engine_count = parallel_config.data_parallel_size_local
|
||||
|
||||
if local_engine_count <= 0:
|
||||
raise ValueError("data_parallel_size_local must be > 0 in headless mode")
|
||||
|
||||
shutdown_requested = False
|
||||
|
||||
# Catch SIGTERM and SIGINT to allow graceful shutdown.
|
||||
def signal_handler(signum, frame):
|
||||
nonlocal shutdown_requested
|
||||
logger.debug("Received %d signal.", signum)
|
||||
if not shutdown_requested:
|
||||
shutdown_requested = True
|
||||
raise SystemExit
|
||||
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
|
||||
if parallel_config.node_rank_within_dp > 0:
|
||||
from vllm.version import __version__ as VLLM_VERSION
|
||||
|
||||
# Run headless workers (for multi-node PP/TP).
|
||||
host = parallel_config.master_addr
|
||||
head_node_address = f"{host}:{parallel_config.master_port}"
|
||||
logger.info(
|
||||
"Launching vLLM (v%s) headless multiproc executor, "
|
||||
"with head node address %s for torch.distributed process group.",
|
||||
VLLM_VERSION,
|
||||
head_node_address,
|
||||
)
|
||||
|
||||
executor = MultiprocExecutor(vllm_config, monitor_workers=False)
|
||||
executor.start_worker_monitor(inline=True)
|
||||
return
|
||||
|
||||
host = parallel_config.data_parallel_master_ip
|
||||
port = parallel_config.data_parallel_rpc_port
|
||||
handshake_address = get_tcp_uri(host, port)
|
||||
|
||||
logger.info(
|
||||
"Launching %d data parallel engine(s) in headless mode, "
|
||||
"with head node address %s.",
|
||||
local_engine_count,
|
||||
handshake_address,
|
||||
)
|
||||
|
||||
# Create the engines.
|
||||
engine_manager = CoreEngineProcManager(
|
||||
target_fn=EngineCoreProc.run_engine_core,
|
||||
local_engine_count=local_engine_count,
|
||||
start_index=vllm_config.parallel_config.data_parallel_rank,
|
||||
local_start_index=0,
|
||||
vllm_config=vllm_config,
|
||||
local_client=False,
|
||||
handshake_address=handshake_address,
|
||||
executor_class=Executor.get_class(vllm_config),
|
||||
log_stats=not engine_args.disable_log_stats,
|
||||
)
|
||||
|
||||
try:
|
||||
engine_manager.join_first()
|
||||
finally:
|
||||
logger.info("Shutting down.")
|
||||
engine_manager.close()
|
||||
|
||||
|
||||
def run_multi_api_server(args: argparse.Namespace):
|
||||
assert not args.headless
|
||||
num_api_servers: int = args.api_server_count
|
||||
assert num_api_servers > 0
|
||||
|
||||
if num_api_servers > 1:
|
||||
setup_multiprocess_prometheus()
|
||||
|
||||
listen_address, sock = setup_server(args)
|
||||
|
||||
engine_args = vllm.AsyncEngineArgs.from_cli_args(args)
|
||||
engine_args._api_process_count = num_api_servers
|
||||
engine_args._api_process_rank = -1
|
||||
|
||||
usage_context = UsageContext.OPENAI_API_SERVER
|
||||
vllm_config = engine_args.create_engine_config(usage_context=usage_context)
|
||||
|
||||
if num_api_servers > 1 and envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
|
||||
raise ValueError(
|
||||
"VLLM_ALLOW_RUNTIME_LORA_UPDATING cannot be used with api_server_count > 1"
|
||||
)
|
||||
|
||||
executor_class = Executor.get_class(vllm_config)
|
||||
log_stats = not engine_args.disable_log_stats
|
||||
|
||||
parallel_config = vllm_config.parallel_config
|
||||
dp_rank = parallel_config.data_parallel_rank
|
||||
external_dp_lb = parallel_config.data_parallel_external_lb
|
||||
hybrid_dp_lb = parallel_config.data_parallel_hybrid_lb
|
||||
assert external_dp_lb or hybrid_dp_lb or dp_rank == 0
|
||||
|
||||
api_server_manager: APIServerProcessManager | None = None
|
||||
|
||||
with launch_core_engines(
|
||||
vllm_config, executor_class, log_stats, num_api_servers
|
||||
) as (local_engine_manager, coordinator, addresses):
|
||||
# Construct common args for the APIServerProcessManager up-front.
|
||||
api_server_manager_kwargs = dict(
|
||||
target_server_fn=run_api_server_worker_proc,
|
||||
listen_address=listen_address,
|
||||
sock=sock,
|
||||
args=args,
|
||||
num_servers=num_api_servers,
|
||||
input_addresses=addresses.inputs,
|
||||
output_addresses=addresses.outputs,
|
||||
stats_update_address=coordinator.get_stats_publish_address()
|
||||
if coordinator
|
||||
else None,
|
||||
)
|
||||
|
||||
# For dp ranks > 0 in external/hybrid DP LB modes, we must delay the
|
||||
# start of the API servers until the local engine is started
|
||||
# (after the launcher context manager exits),
|
||||
# since we get the front-end stats update address from the coordinator
|
||||
# via the handshake with the local engine.
|
||||
if dp_rank == 0 or not (external_dp_lb or hybrid_dp_lb):
|
||||
# Start API servers using the manager.
|
||||
api_server_manager = APIServerProcessManager(**api_server_manager_kwargs)
|
||||
|
||||
# Start API servers now if they weren't already started.
|
||||
if api_server_manager is None:
|
||||
api_server_manager_kwargs["stats_update_address"] = (
|
||||
addresses.frontend_stats_publish_address
|
||||
)
|
||||
api_server_manager = APIServerProcessManager(**api_server_manager_kwargs)
|
||||
|
||||
# Wait for API servers
|
||||
wait_for_completion_or_failure(
|
||||
api_server_manager=api_server_manager,
|
||||
engine_manager=local_engine_manager,
|
||||
coordinator=coordinator,
|
||||
)
|
||||
|
||||
|
||||
def run_api_server_worker_proc(
|
||||
listen_address, sock, args, client_config=None, **uvicorn_kwargs
|
||||
) -> None:
|
||||
"""Entrypoint for individual API server worker processes."""
|
||||
client_config = client_config or {}
|
||||
server_index = client_config.get("client_index", 0)
|
||||
|
||||
# Set process title and add process-specific prefix to stdout and stderr.
|
||||
set_process_title("APIServer", str(server_index))
|
||||
decorate_logs()
|
||||
|
||||
uvloop.run(
|
||||
run_server_worker(listen_address, sock, args, client_config, **uvicorn_kwargs)
|
||||
)
|
||||
29
vllm_old/entrypoints/cli/types.py
Normal file
29
vllm_old/entrypoints/cli/types.py
Normal file
@@ -0,0 +1,29 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import argparse
|
||||
import typing
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
else:
|
||||
FlexibleArgumentParser = argparse.ArgumentParser
|
||||
|
||||
|
||||
class CLISubcommand:
|
||||
"""Base class for CLI argument handlers."""
|
||||
|
||||
name: str
|
||||
|
||||
@staticmethod
|
||||
def cmd(args: argparse.Namespace) -> None:
|
||||
raise NotImplementedError("Subclasses should implement this method")
|
||||
|
||||
def validate(self, args: argparse.Namespace) -> None:
|
||||
# No validation by default
|
||||
pass
|
||||
|
||||
def subparser_init(
|
||||
self, subparsers: argparse._SubParsersAction
|
||||
) -> FlexibleArgumentParser:
|
||||
raise NotImplementedError("Subclasses should implement this method")
|
||||
10
vllm_old/entrypoints/constants.py
Normal file
10
vllm_old/entrypoints/constants.py
Normal file
@@ -0,0 +1,10 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Shared constants for vLLM entrypoints.
|
||||
"""
|
||||
|
||||
# HTTP header limits for h11 parser
|
||||
# These constants help mitigate header abuse attacks
|
||||
H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT = 4194304 # 4 MB
|
||||
H11_MAX_HEADER_COUNT_DEFAULT = 256
|
||||
572
vllm_old/entrypoints/context.py
Normal file
572
vllm_old/entrypoints/context.py
Normal file
@@ -0,0 +1,572 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import asyncio
|
||||
import contextlib
|
||||
import json
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import AsyncExitStack
|
||||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
from openai.types.responses.tool import Mcp
|
||||
from openai_harmony import Author, Message, Role, StreamState, TextContent
|
||||
|
||||
from vllm import envs
|
||||
from vllm.entrypoints.harmony_utils import (
|
||||
get_encoding,
|
||||
get_streamable_parser_for_assistant,
|
||||
render_for_completion,
|
||||
)
|
||||
from vllm.entrypoints.tool import Tool
|
||||
from vllm.entrypoints.tool_server import ToolServer
|
||||
from vllm.outputs import RequestOutput
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from mcp.client import ClientSession
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# This is currently needed as the tool type doesn't 1:1 match the
|
||||
# tool namespace, which is what is used to look up the
|
||||
# connection to the tool server
|
||||
_TOOL_NAME_TO_TYPE_MAP = {
|
||||
"browser": "web_search_preview",
|
||||
"python": "code_interpreter",
|
||||
"container": "container",
|
||||
}
|
||||
|
||||
|
||||
def _map_tool_name_to_tool_type(tool_name: str) -> str:
|
||||
if tool_name not in _TOOL_NAME_TO_TYPE_MAP:
|
||||
available_tools = ", ".join(_TOOL_NAME_TO_TYPE_MAP.keys())
|
||||
raise ValueError(
|
||||
f"Built-in tool name '{tool_name}' not defined in mapping. "
|
||||
f"Available tools: {available_tools}"
|
||||
)
|
||||
return _TOOL_NAME_TO_TYPE_MAP[tool_name]
|
||||
|
||||
|
||||
class TurnMetrics:
|
||||
"""Tracks token and toolcall details for a single conversation turn."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_tokens=0,
|
||||
output_tokens=0,
|
||||
cached_input_tokens=0,
|
||||
tool_output_tokens=0,
|
||||
):
|
||||
self.input_tokens = input_tokens
|
||||
self.output_tokens = output_tokens
|
||||
self.cached_input_tokens = cached_input_tokens
|
||||
self.tool_output_tokens = tool_output_tokens
|
||||
|
||||
def reset(self):
|
||||
"""Reset counters for a new turn."""
|
||||
self.input_tokens = 0
|
||||
self.output_tokens = 0
|
||||
self.cached_input_tokens = 0
|
||||
self.tool_output_tokens = 0
|
||||
|
||||
def copy(self):
|
||||
"""Create a copy of this turn's token counts."""
|
||||
return TurnMetrics(
|
||||
self.input_tokens,
|
||||
self.output_tokens,
|
||||
self.cached_input_tokens,
|
||||
self.tool_output_tokens,
|
||||
)
|
||||
|
||||
|
||||
class ConversationContext(ABC):
|
||||
@abstractmethod
|
||||
def append_output(self, output: RequestOutput) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def append_tool_output(self, output) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def call_tool(self) -> list[Message]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def need_builtin_tool_call(self) -> bool:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def render_for_completion(self) -> list[int]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def init_tool_sessions(
|
||||
self,
|
||||
tool_server: ToolServer | None,
|
||||
exit_stack: AsyncExitStack,
|
||||
request_id: str,
|
||||
mcp_tools: dict[str, Mcp],
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def cleanup_session(self) -> None:
|
||||
raise NotImplementedError("Should not be called.")
|
||||
|
||||
|
||||
def _create_json_parse_error_messages(
|
||||
last_msg: Message, e: json.JSONDecodeError
|
||||
) -> list[Message]:
|
||||
"""
|
||||
Creates an error message when json parse failed.
|
||||
"""
|
||||
error_msg = (
|
||||
f"Error parsing tool arguments as JSON: {str(e)}. "
|
||||
"Please ensure the tool call arguments are valid JSON and try again."
|
||||
)
|
||||
content = TextContent(text=error_msg)
|
||||
author = Author(role=Role.TOOL, name=last_msg.recipient)
|
||||
return [
|
||||
Message(
|
||||
author=author,
|
||||
content=[content],
|
||||
recipient=Role.ASSISTANT,
|
||||
channel=last_msg.channel,
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
class SimpleContext(ConversationContext):
|
||||
def __init__(self):
|
||||
self.last_output = None
|
||||
self.num_prompt_tokens = 0
|
||||
self.num_output_tokens = 0
|
||||
self.num_cached_tokens = 0
|
||||
# todo num_reasoning_tokens is not implemented yet.
|
||||
self.num_reasoning_tokens = 0
|
||||
# not implemented yet for SimpleContext
|
||||
self.all_turn_metrics = []
|
||||
|
||||
def append_output(self, output) -> None:
|
||||
self.last_output = output
|
||||
if not isinstance(output, RequestOutput):
|
||||
raise ValueError("SimpleContext only supports RequestOutput.")
|
||||
self.num_prompt_tokens = len(output.prompt_token_ids or [])
|
||||
self.num_cached_tokens = output.num_cached_tokens or 0
|
||||
self.num_output_tokens += len(output.outputs[0].token_ids or [])
|
||||
|
||||
def append_tool_output(self, output) -> None:
|
||||
raise NotImplementedError("Should not be called.")
|
||||
|
||||
def need_builtin_tool_call(self) -> bool:
|
||||
return False
|
||||
|
||||
async def call_tool(self) -> list[Message]:
|
||||
raise NotImplementedError("Should not be called.")
|
||||
|
||||
def render_for_completion(self) -> list[int]:
|
||||
raise NotImplementedError("Should not be called.")
|
||||
|
||||
async def init_tool_sessions(
|
||||
self,
|
||||
tool_server: ToolServer | None,
|
||||
exit_stack: AsyncExitStack,
|
||||
request_id: str,
|
||||
mcp_tools: dict[str, Mcp],
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
async def cleanup_session(self) -> None:
|
||||
raise NotImplementedError("Should not be called.")
|
||||
|
||||
|
||||
class HarmonyContext(ConversationContext):
|
||||
def __init__(
|
||||
self,
|
||||
messages: list,
|
||||
available_tools: list[str],
|
||||
):
|
||||
self._messages = messages
|
||||
self.finish_reason: str | None = None
|
||||
self.available_tools = available_tools
|
||||
self._tool_sessions: dict[str, ClientSession | Tool] = {}
|
||||
self.called_tools: set[str] = set()
|
||||
|
||||
self.parser = get_streamable_parser_for_assistant()
|
||||
self.num_init_messages = len(messages)
|
||||
self.num_prompt_tokens = 0
|
||||
self.num_output_tokens = 0
|
||||
self.num_cached_tokens = 0
|
||||
self.num_reasoning_tokens = 0
|
||||
self.num_tool_output_tokens = 0
|
||||
|
||||
# Turn tracking - replaces multiple individual tracking variables
|
||||
self.current_turn_metrics = TurnMetrics()
|
||||
# Track metrics for all turns
|
||||
self.all_turn_metrics: list[TurnMetrics] = []
|
||||
self.is_first_turn = True
|
||||
self.first_tok_of_message = True # For streaming support
|
||||
|
||||
def _update_num_reasoning_tokens(self):
|
||||
# Count all analysis and commentary channels as reasoning tokens
|
||||
if self.parser.current_channel in {"analysis", "commentary"}:
|
||||
self.num_reasoning_tokens += 1
|
||||
|
||||
def append_output(self, output: RequestOutput) -> None:
|
||||
output_token_ids = output.outputs[0].token_ids
|
||||
self.parser = get_streamable_parser_for_assistant()
|
||||
for token_id in output_token_ids:
|
||||
self.parser.process(token_id)
|
||||
# Check if the current token is part of reasoning content
|
||||
self._update_num_reasoning_tokens()
|
||||
self._update_prefill_token_usage(output)
|
||||
self._update_decode_token_usage(output)
|
||||
# Append current turn to all turn list for next turn's calculations
|
||||
self.all_turn_metrics.append(self.current_turn_metrics.copy())
|
||||
self.current_turn_metrics.reset()
|
||||
# append_output is called only once before tool calling
|
||||
# in non-streaming case
|
||||
# so we can append all the parser messages to _messages
|
||||
output_msgs = self.parser.messages
|
||||
# The responses finish reason is set in the last message
|
||||
self.finish_reason = output.outputs[0].finish_reason
|
||||
self._messages.extend(output_msgs)
|
||||
|
||||
def append_tool_output(self, output: list[Message]) -> None:
|
||||
output_msgs = output
|
||||
self._messages.extend(output_msgs)
|
||||
|
||||
def _update_prefill_token_usage(self, output: RequestOutput) -> None:
|
||||
"""Update token usage statistics for the prefill phase of generation.
|
||||
|
||||
The prefill phase processes the input prompt tokens. This method:
|
||||
1. Counts the prompt tokens for this turn
|
||||
2. Calculates tool output tokens for multi-turn conversations
|
||||
3. Updates cached token counts
|
||||
4. Tracks state for next turn calculations
|
||||
|
||||
Tool output tokens are calculated as:
|
||||
current_prompt_tokens - last_turn_prompt_tokens -
|
||||
last_turn_output_tokens
|
||||
This represents tokens added between turns (typically tool responses).
|
||||
|
||||
Args:
|
||||
output: The RequestOutput containing prompt token information
|
||||
"""
|
||||
if output.prompt_token_ids is not None:
|
||||
this_turn_input_tokens = len(output.prompt_token_ids)
|
||||
else:
|
||||
this_turn_input_tokens = 0
|
||||
logger.error("RequestOutput appended contains no prompt_token_ids.")
|
||||
|
||||
# Update current turn input tokens
|
||||
self.current_turn_metrics.input_tokens = this_turn_input_tokens
|
||||
self.num_prompt_tokens += this_turn_input_tokens
|
||||
|
||||
# Calculate tool tokens (except on first turn)
|
||||
if self.is_first_turn:
|
||||
self.is_first_turn = False
|
||||
else:
|
||||
previous_turn = self.all_turn_metrics[-1]
|
||||
# start counting tool after first turn
|
||||
# tool tokens = this turn prefill - last turn prefill -
|
||||
# last turn decode
|
||||
this_turn_tool_tokens = (
|
||||
self.current_turn_metrics.input_tokens
|
||||
- previous_turn.input_tokens
|
||||
- previous_turn.output_tokens
|
||||
)
|
||||
|
||||
# Handle negative tool token counts (shouldn't happen in normal
|
||||
# cases)
|
||||
if this_turn_tool_tokens < 0:
|
||||
logger.error(
|
||||
"Negative tool output tokens calculated: %d "
|
||||
"(current_input=%d, previous_input=%d, "
|
||||
"previous_output=%d). Setting to 0.",
|
||||
this_turn_tool_tokens,
|
||||
self.current_turn_metrics.input_tokens,
|
||||
previous_turn.input_tokens,
|
||||
previous_turn.output_tokens,
|
||||
)
|
||||
this_turn_tool_tokens = 0
|
||||
|
||||
self.num_tool_output_tokens += this_turn_tool_tokens
|
||||
self.current_turn_metrics.tool_output_tokens = this_turn_tool_tokens
|
||||
|
||||
# Update cached tokens
|
||||
num_cached_token = output.num_cached_tokens
|
||||
if num_cached_token is not None:
|
||||
self.num_cached_tokens += num_cached_token
|
||||
self.current_turn_metrics.cached_input_tokens = num_cached_token
|
||||
|
||||
def _update_decode_token_usage(self, output: RequestOutput) -> int:
|
||||
"""Update token usage statistics for the decode phase of generation.
|
||||
|
||||
The decode phase processes the generated output tokens. This method:
|
||||
1. Counts output tokens from all completion outputs
|
||||
2. Updates the total output token count
|
||||
3. Tracks tokens generated in the current turn
|
||||
|
||||
In streaming mode, this is called for each token generated.
|
||||
In non-streaming mode, this is called once with all output tokens.
|
||||
|
||||
Args:
|
||||
output: The RequestOutput containing generated token information
|
||||
|
||||
Returns:
|
||||
int: Number of output tokens processed in this call
|
||||
"""
|
||||
updated_output_token_count = 0
|
||||
if output.outputs:
|
||||
for completion_output in output.outputs:
|
||||
# only keep last round
|
||||
updated_output_token_count += len(completion_output.token_ids)
|
||||
self.num_output_tokens += updated_output_token_count
|
||||
self.current_turn_metrics.output_tokens += updated_output_token_count
|
||||
return updated_output_token_count
|
||||
|
||||
@property
|
||||
def messages(self) -> list:
|
||||
return self._messages
|
||||
|
||||
def need_builtin_tool_call(self) -> bool:
|
||||
last_msg = self.messages[-1]
|
||||
recipient = last_msg.recipient
|
||||
return recipient is not None and (
|
||||
recipient.startswith("browser.")
|
||||
or recipient.startswith("python")
|
||||
or recipient.startswith("container.")
|
||||
)
|
||||
|
||||
async def call_tool(self) -> list[Message]:
|
||||
if not self.messages:
|
||||
return []
|
||||
last_msg = self.messages[-1]
|
||||
recipient = last_msg.recipient
|
||||
if recipient is not None:
|
||||
if recipient.startswith("browser."):
|
||||
return await self.call_search_tool(
|
||||
self._tool_sessions["browser"], last_msg
|
||||
)
|
||||
elif recipient.startswith("python"):
|
||||
return await self.call_python_tool(
|
||||
self._tool_sessions["python"], last_msg
|
||||
)
|
||||
elif recipient.startswith("container."):
|
||||
return await self.call_container_tool(
|
||||
self._tool_sessions["container"], last_msg
|
||||
)
|
||||
raise ValueError("No tool call found")
|
||||
|
||||
def render_for_completion(self) -> list[int]:
|
||||
return render_for_completion(self.messages)
|
||||
|
||||
async def call_search_tool(
|
||||
self, tool_session: Union["ClientSession", Tool], last_msg: Message
|
||||
) -> list[Message]:
|
||||
self.called_tools.add("browser")
|
||||
if isinstance(tool_session, Tool):
|
||||
return await tool_session.get_result(self)
|
||||
tool_name = last_msg.recipient.split(".")[1]
|
||||
if envs.VLLM_TOOL_JSON_ERROR_AUTOMATIC_RETRY:
|
||||
try:
|
||||
args = json.loads(last_msg.content[0].text)
|
||||
except json.JSONDecodeError as e:
|
||||
return _create_json_parse_error_messages(last_msg, e)
|
||||
else:
|
||||
args = json.loads(last_msg.content[0].text)
|
||||
result = await tool_session.call_tool(tool_name, args)
|
||||
result_str = result.content[0].text
|
||||
content = TextContent(text=result_str)
|
||||
author = Author(role=Role.TOOL, name=last_msg.recipient)
|
||||
return [
|
||||
Message(
|
||||
author=author,
|
||||
content=[content],
|
||||
recipient=Role.ASSISTANT,
|
||||
channel=last_msg.channel,
|
||||
)
|
||||
]
|
||||
|
||||
async def call_python_tool(
|
||||
self, tool_session: Union["ClientSession", Tool], last_msg: Message
|
||||
) -> list[Message]:
|
||||
self.called_tools.add("python")
|
||||
if isinstance(tool_session, Tool):
|
||||
return await tool_session.get_result(self)
|
||||
param = {
|
||||
"code": last_msg.content[0].text,
|
||||
}
|
||||
result = await tool_session.call_tool("python", param)
|
||||
result_str = result.content[0].text
|
||||
|
||||
content = TextContent(text=result_str)
|
||||
author = Author(role=Role.TOOL, name="python")
|
||||
|
||||
return [
|
||||
Message(
|
||||
author=author,
|
||||
content=[content],
|
||||
channel=last_msg.channel,
|
||||
recipient=Role.ASSISTANT,
|
||||
)
|
||||
]
|
||||
|
||||
async def init_tool_sessions(
|
||||
self,
|
||||
tool_server: ToolServer | None,
|
||||
exit_stack: AsyncExitStack,
|
||||
request_id: str,
|
||||
mcp_tools: dict[str, Mcp],
|
||||
):
|
||||
if tool_server:
|
||||
for tool_name in self.available_tools:
|
||||
if tool_name not in self._tool_sessions:
|
||||
tool_type = _map_tool_name_to_tool_type(tool_name)
|
||||
headers = (
|
||||
mcp_tools[tool_type].headers if tool_type in mcp_tools else None
|
||||
)
|
||||
tool_session = await exit_stack.enter_async_context(
|
||||
tool_server.new_session(tool_name, request_id, headers)
|
||||
)
|
||||
self._tool_sessions[tool_name] = tool_session
|
||||
exit_stack.push_async_exit(self.cleanup_session)
|
||||
|
||||
async def call_container_tool(
|
||||
self, tool_session: Union["ClientSession", Tool], last_msg: Message
|
||||
) -> list[Message]:
|
||||
"""
|
||||
Call container tool. Expect this to be run in a stateful docker
|
||||
with command line terminal.
|
||||
The official container tool would at least
|
||||
expect the following format:
|
||||
- for tool name: exec
|
||||
- args:
|
||||
{
|
||||
"cmd":List[str] "command to execute",
|
||||
"workdir":optional[str] "current working directory",
|
||||
"env":optional[object/dict] "environment variables",
|
||||
"session_name":optional[str] "session name",
|
||||
"timeout":optional[int] "timeout in seconds",
|
||||
"user":optional[str] "user name",
|
||||
}
|
||||
"""
|
||||
self.called_tools.add("container")
|
||||
if isinstance(tool_session, Tool):
|
||||
return await tool_session.get_result(self)
|
||||
tool_name = last_msg.recipient.split(".")[1].split(" ")[0]
|
||||
if envs.VLLM_TOOL_JSON_ERROR_AUTOMATIC_RETRY:
|
||||
try:
|
||||
args = json.loads(last_msg.content[0].text)
|
||||
except json.JSONDecodeError as e:
|
||||
return _create_json_parse_error_messages(last_msg, e)
|
||||
else:
|
||||
args = json.loads(last_msg.content[0].text)
|
||||
result = await tool_session.call_tool(tool_name, args)
|
||||
result_str = result.content[0].text
|
||||
content = TextContent(text=result_str)
|
||||
author = Author(role=Role.TOOL, name=last_msg.recipient)
|
||||
return [
|
||||
Message(
|
||||
author=author,
|
||||
content=[content],
|
||||
recipient=Role.ASSISTANT,
|
||||
channel=last_msg.channel,
|
||||
)
|
||||
]
|
||||
|
||||
async def cleanup_session(self, *args, **kwargs) -> None:
|
||||
"""Can be used as coro to used in __aexit__"""
|
||||
|
||||
async def cleanup_tool_session(tool_session):
|
||||
if not isinstance(tool_session, Tool):
|
||||
logger.info(
|
||||
"Cleaning up tool session for %s", tool_session._client_info
|
||||
)
|
||||
with contextlib.suppress(Exception):
|
||||
await tool_session.call_tool("cleanup_session", {})
|
||||
|
||||
await asyncio.gather(
|
||||
*(
|
||||
cleanup_tool_session(self._tool_sessions[tool])
|
||||
for tool in self.called_tools
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class StreamingHarmonyContext(HarmonyContext):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.last_output = None
|
||||
|
||||
self.parser = get_streamable_parser_for_assistant()
|
||||
self.encoding = get_encoding()
|
||||
self.last_tok = None
|
||||
self.first_tok_of_message = True
|
||||
|
||||
@property
|
||||
def messages(self) -> list:
|
||||
return self._messages
|
||||
|
||||
def append_output(self, output: RequestOutput) -> None:
|
||||
# append_output is called for each output token in streaming case,
|
||||
# so we only want to add the prompt tokens once for each message.
|
||||
if self.first_tok_of_message:
|
||||
self._update_prefill_token_usage(output)
|
||||
# Reset self.first_tok_of_message if needed:
|
||||
# if the current token is the last one of the current message
|
||||
# (finished=True), then the next token processed will mark the
|
||||
# beginning of a new message
|
||||
self.first_tok_of_message = output.finished
|
||||
for tok in output.outputs[0].token_ids:
|
||||
self.parser.process(tok)
|
||||
self._update_decode_token_usage(output)
|
||||
|
||||
# For streaming, update previous turn when message is complete
|
||||
if output.finished:
|
||||
self.all_turn_metrics.append(self.current_turn_metrics.copy())
|
||||
self.current_turn_metrics.reset()
|
||||
# Check if the current token is part of reasoning content
|
||||
self._update_num_reasoning_tokens()
|
||||
self.last_tok = tok
|
||||
if len(self._messages) - self.num_init_messages < len(self.parser.messages):
|
||||
self._messages.extend(
|
||||
self.parser.messages[len(self._messages) - self.num_init_messages :]
|
||||
)
|
||||
|
||||
def append_tool_output(self, output: list[Message]) -> None:
|
||||
# Handle the case of tool output in direct message format
|
||||
assert len(output) == 1, "Tool output should be a single message"
|
||||
msg = output[0]
|
||||
# Sometimes the recipient is not set for tool messages,
|
||||
# so we set it to "assistant"
|
||||
if msg.author.role == Role.TOOL and msg.recipient is None:
|
||||
msg.recipient = "assistant"
|
||||
toks = self.encoding.render(msg)
|
||||
for tok in toks:
|
||||
self.parser.process(tok)
|
||||
self.last_tok = toks[-1]
|
||||
# TODO: add tool_output messages to self._messages
|
||||
|
||||
def is_expecting_start(self) -> bool:
|
||||
return self.parser.state == StreamState.EXPECT_START
|
||||
|
||||
def is_assistant_action_turn(self) -> bool:
|
||||
return self.last_tok in self.encoding.stop_tokens_for_assistant_actions()
|
||||
|
||||
def render_for_completion(self) -> list[int]:
|
||||
# now this list of tokens as next turn's starting tokens
|
||||
# `<|start|>assistant`,
|
||||
# we need to process them in parser.
|
||||
rendered_tokens = super().render_for_completion()
|
||||
|
||||
last_n = -1
|
||||
to_process = []
|
||||
while rendered_tokens[last_n] != self.last_tok:
|
||||
to_process.append(rendered_tokens[last_n])
|
||||
last_n -= 1
|
||||
for tok in reversed(to_process):
|
||||
self.parser.process(tok)
|
||||
|
||||
return rendered_tokens
|
||||
57
vllm_old/entrypoints/dynamic_lora.py
Normal file
57
vllm_old/entrypoints/dynamic_lora.py
Normal file
@@ -0,0 +1,57 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import model_hosting_container_standards.sagemaker as sagemaker_standards
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from fastapi.responses import JSONResponse, Response
|
||||
|
||||
from vllm.entrypoints.openai.api_server import models, validate_json_request
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ErrorResponse,
|
||||
LoadLoRAAdapterRequest,
|
||||
UnloadLoRAAdapterRequest,
|
||||
)
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def register_dynamic_lora_routes(router: APIRouter):
|
||||
@sagemaker_standards.register_load_adapter_handler(
|
||||
request_shape={
|
||||
"lora_name": "body.name",
|
||||
"lora_path": "body.src",
|
||||
},
|
||||
)
|
||||
@router.post("/v1/load_lora_adapter", dependencies=[Depends(validate_json_request)])
|
||||
async def load_lora_adapter(request: LoadLoRAAdapterRequest, raw_request: Request):
|
||||
handler: OpenAIServingModels = models(raw_request)
|
||||
response = await handler.load_lora_adapter(request)
|
||||
if isinstance(response, ErrorResponse):
|
||||
return JSONResponse(
|
||||
content=response.model_dump(), status_code=response.error.code
|
||||
)
|
||||
|
||||
return Response(status_code=200, content=response)
|
||||
|
||||
@sagemaker_standards.register_unload_adapter_handler(
|
||||
request_shape={
|
||||
"lora_name": "path_params.adapter_name",
|
||||
}
|
||||
)
|
||||
@router.post(
|
||||
"/v1/unload_lora_adapter", dependencies=[Depends(validate_json_request)]
|
||||
)
|
||||
async def unload_lora_adapter(
|
||||
request: UnloadLoRAAdapterRequest, raw_request: Request
|
||||
):
|
||||
handler: OpenAIServingModels = models(raw_request)
|
||||
response = await handler.unload_lora_adapter(request)
|
||||
if isinstance(response, ErrorResponse):
|
||||
return JSONResponse(
|
||||
content=response.model_dump(), status_code=response.error.code
|
||||
)
|
||||
|
||||
return Response(status_code=200, content=response)
|
||||
|
||||
return router
|
||||
535
vllm_old/entrypoints/harmony_utils.py
Normal file
535
vllm_old/entrypoints/harmony_utils.py
Normal file
@@ -0,0 +1,535 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import datetime
|
||||
import json
|
||||
from collections.abc import Iterable, Sequence
|
||||
from typing import Literal
|
||||
|
||||
from openai.types.responses import (
|
||||
ResponseFunctionToolCall,
|
||||
ResponseOutputItem,
|
||||
ResponseOutputMessage,
|
||||
ResponseOutputText,
|
||||
ResponseReasoningItem,
|
||||
)
|
||||
from openai.types.responses.response_function_web_search import (
|
||||
ActionFind,
|
||||
ActionOpenPage,
|
||||
ActionSearch,
|
||||
ResponseFunctionWebSearch,
|
||||
)
|
||||
from openai.types.responses.response_reasoning_item import (
|
||||
Content as ResponseReasoningTextContent,
|
||||
)
|
||||
from openai.types.responses.tool import Tool
|
||||
from openai_harmony import (
|
||||
Author,
|
||||
ChannelConfig,
|
||||
Conversation,
|
||||
DeveloperContent,
|
||||
HarmonyEncodingName,
|
||||
Message,
|
||||
ReasoningEffort,
|
||||
Role,
|
||||
StreamableParser,
|
||||
SystemContent,
|
||||
TextContent,
|
||||
ToolDescription,
|
||||
load_harmony_encoding,
|
||||
)
|
||||
from openai_harmony import Message as OpenAIHarmonyMessage
|
||||
from openai_harmony import Role as OpenAIHarmonyRole
|
||||
|
||||
from vllm import envs
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionToolsParam,
|
||||
ResponseInputOutputItem,
|
||||
ResponsesRequest,
|
||||
)
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
REASONING_EFFORT = {
|
||||
"high": ReasoningEffort.HIGH,
|
||||
"medium": ReasoningEffort.MEDIUM,
|
||||
"low": ReasoningEffort.LOW,
|
||||
}
|
||||
|
||||
_harmony_encoding = None
|
||||
|
||||
# Builtin tools that should be included in the system message when
|
||||
# they are available and requested by the user.
|
||||
# Tool args are provided by MCP tool descriptions. Output
|
||||
# of the tools are stringified.
|
||||
MCP_BUILTIN_TOOLS: set[str] = {
|
||||
"web_search_preview",
|
||||
"code_interpreter",
|
||||
"container",
|
||||
}
|
||||
|
||||
|
||||
def has_custom_tools(tool_types: set[str]) -> bool:
|
||||
"""
|
||||
Checks if the given tool types are custom tools
|
||||
(i.e. any tool other than MCP buildin tools)
|
||||
"""
|
||||
return not tool_types.issubset(MCP_BUILTIN_TOOLS)
|
||||
|
||||
|
||||
def get_encoding():
|
||||
global _harmony_encoding
|
||||
if _harmony_encoding is None:
|
||||
_harmony_encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
|
||||
return _harmony_encoding
|
||||
|
||||
|
||||
def get_system_message(
|
||||
model_identity: str | None = None,
|
||||
reasoning_effort: Literal["high", "medium", "low"] | None = None,
|
||||
start_date: str | None = None,
|
||||
browser_description: str | None = None,
|
||||
python_description: str | None = None,
|
||||
container_description: str | None = None,
|
||||
instructions: str | None = None,
|
||||
with_custom_tools: bool = False,
|
||||
) -> Message:
|
||||
sys_msg_content = SystemContent.new()
|
||||
if model_identity is not None:
|
||||
sys_msg_content = sys_msg_content.with_model_identity(model_identity)
|
||||
if instructions is not None and envs.VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS:
|
||||
current_identity = sys_msg_content.model_identity
|
||||
new_identity = (
|
||||
f"{current_identity}\n{instructions}" if current_identity else instructions
|
||||
)
|
||||
sys_msg_content = sys_msg_content.with_model_identity(new_identity)
|
||||
if reasoning_effort is not None:
|
||||
sys_msg_content = sys_msg_content.with_reasoning_effort(
|
||||
REASONING_EFFORT[reasoning_effort]
|
||||
)
|
||||
if start_date is None:
|
||||
# NOTE(woosuk): This brings non-determinism in vLLM. Be careful.
|
||||
start_date = datetime.datetime.now().strftime("%Y-%m-%d")
|
||||
sys_msg_content = sys_msg_content.with_conversation_start_date(start_date)
|
||||
if browser_description is not None:
|
||||
sys_msg_content = sys_msg_content.with_tools(browser_description)
|
||||
if python_description is not None:
|
||||
sys_msg_content = sys_msg_content.with_tools(python_description)
|
||||
if container_description is not None:
|
||||
sys_msg_content = sys_msg_content.with_tools(container_description)
|
||||
if not with_custom_tools:
|
||||
channel_config = sys_msg_content.channel_config
|
||||
invalid_channel = "commentary"
|
||||
new_config = ChannelConfig.require_channels(
|
||||
[c for c in channel_config.valid_channels if c != invalid_channel]
|
||||
)
|
||||
sys_msg_content = sys_msg_content.with_channel_config(new_config)
|
||||
sys_msg = Message.from_role_and_content(Role.SYSTEM, sys_msg_content)
|
||||
return sys_msg
|
||||
|
||||
|
||||
def create_tool_definition(tool: ChatCompletionToolsParam | Tool):
|
||||
if isinstance(tool, ChatCompletionToolsParam):
|
||||
return ToolDescription.new(
|
||||
name=tool.function.name,
|
||||
description=tool.function.description,
|
||||
parameters=tool.function.parameters,
|
||||
)
|
||||
return ToolDescription.new(
|
||||
name=tool.name,
|
||||
description=tool.description,
|
||||
parameters=tool.parameters,
|
||||
)
|
||||
|
||||
|
||||
def get_developer_message(
|
||||
instructions: str | None = None,
|
||||
tools: list[Tool | ChatCompletionToolsParam] | None = None,
|
||||
) -> Message:
|
||||
dev_msg_content = DeveloperContent.new()
|
||||
if instructions is not None and not envs.VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS:
|
||||
dev_msg_content = dev_msg_content.with_instructions(instructions)
|
||||
if tools is not None:
|
||||
function_tools: list[Tool | ChatCompletionToolsParam] = []
|
||||
for tool in tools:
|
||||
if tool.type in (
|
||||
"web_search_preview",
|
||||
"code_interpreter",
|
||||
"container",
|
||||
"mcp",
|
||||
):
|
||||
# These are built-in tools that are added to the system message.
|
||||
# Adding in MCP for now until we support MCP tools executed
|
||||
# server side
|
||||
pass
|
||||
|
||||
elif tool.type == "function":
|
||||
function_tools.append(tool)
|
||||
else:
|
||||
raise ValueError(f"tool type {tool.type} not supported")
|
||||
if function_tools:
|
||||
function_tool_descriptions = [
|
||||
create_tool_definition(tool) for tool in function_tools
|
||||
]
|
||||
dev_msg_content = dev_msg_content.with_function_tools(
|
||||
function_tool_descriptions
|
||||
)
|
||||
dev_msg = Message.from_role_and_content(Role.DEVELOPER, dev_msg_content)
|
||||
return dev_msg
|
||||
|
||||
|
||||
def get_user_message(content: str) -> Message:
|
||||
return Message.from_role_and_content(Role.USER, content)
|
||||
|
||||
|
||||
def parse_response_input(
|
||||
response_msg: ResponseInputOutputItem,
|
||||
prev_responses: list[ResponseOutputItem | ResponseReasoningItem],
|
||||
) -> Message:
|
||||
if not isinstance(response_msg, dict):
|
||||
response_msg = response_msg.model_dump()
|
||||
if "type" not in response_msg or response_msg["type"] == "message":
|
||||
role = response_msg["role"]
|
||||
content = response_msg["content"]
|
||||
if role == "system":
|
||||
# User is trying to set a system message. Change it to:
|
||||
# <|start|>developer<|message|># Instructions
|
||||
# {instructions}<|end|>
|
||||
role = "developer"
|
||||
text_prefix = "Instructions:\n"
|
||||
else:
|
||||
text_prefix = ""
|
||||
if isinstance(content, str):
|
||||
msg = Message.from_role_and_content(role, text_prefix + content)
|
||||
else:
|
||||
contents = [TextContent(text=text_prefix + c["text"]) for c in content]
|
||||
msg = Message.from_role_and_contents(role, contents)
|
||||
if role == "assistant":
|
||||
msg = msg.with_channel("final")
|
||||
elif response_msg["type"] == "function_call_output":
|
||||
call_id = response_msg["call_id"]
|
||||
call_response: ResponseFunctionToolCall | None = None
|
||||
for prev_response in reversed(prev_responses):
|
||||
if (
|
||||
isinstance(prev_response, ResponseFunctionToolCall)
|
||||
and prev_response.call_id == call_id
|
||||
):
|
||||
call_response = prev_response
|
||||
break
|
||||
if call_response is None:
|
||||
raise ValueError(f"No call message found for {call_id}")
|
||||
msg = Message.from_author_and_content(
|
||||
Author.new(Role.TOOL, f"functions.{call_response.name}"),
|
||||
response_msg["output"],
|
||||
)
|
||||
elif response_msg["type"] == "reasoning":
|
||||
content = response_msg["content"]
|
||||
assert len(content) == 1
|
||||
msg = Message.from_role_and_content(Role.ASSISTANT, content[0]["text"])
|
||||
elif response_msg["type"] == "function_call":
|
||||
msg = Message.from_role_and_content(Role.ASSISTANT, response_msg["arguments"])
|
||||
msg = msg.with_channel("commentary")
|
||||
msg = msg.with_recipient(f"functions.{response_msg['name']}")
|
||||
msg = msg.with_content_type("json")
|
||||
else:
|
||||
raise ValueError(f"Unknown input type: {response_msg['type']}")
|
||||
return msg
|
||||
|
||||
|
||||
def parse_input_to_harmony_message(chat_msg) -> list[Message]:
|
||||
if not isinstance(chat_msg, dict):
|
||||
# Handle Pydantic models
|
||||
chat_msg = chat_msg.model_dump(exclude_none=True)
|
||||
|
||||
role = chat_msg.get("role")
|
||||
|
||||
# Assistant message with tool calls
|
||||
tool_calls = chat_msg.get("tool_calls")
|
||||
if role == "assistant" and tool_calls:
|
||||
msgs: list[Message] = []
|
||||
for call in tool_calls:
|
||||
func = call.get("function", {})
|
||||
name = func.get("name", "")
|
||||
arguments = func.get("arguments", "") or ""
|
||||
msg = Message.from_role_and_content(Role.ASSISTANT, arguments)
|
||||
msg = msg.with_channel("commentary")
|
||||
msg = msg.with_recipient(f"functions.{name}")
|
||||
msg = msg.with_content_type("json")
|
||||
msgs.append(msg)
|
||||
return msgs
|
||||
|
||||
# Tool role message (tool output)
|
||||
if role == "tool":
|
||||
name = chat_msg.get("name", "")
|
||||
content = chat_msg.get("content", "") or ""
|
||||
if isinstance(content, list):
|
||||
# Handle array format for tool message content
|
||||
# by concatenating all text parts.
|
||||
content = "".join(
|
||||
item.get("text", "")
|
||||
for item in content
|
||||
if isinstance(item, dict) and item.get("type") == "text"
|
||||
)
|
||||
|
||||
msg = Message.from_author_and_content(
|
||||
Author.new(Role.TOOL, f"functions.{name}"), content
|
||||
).with_channel("commentary")
|
||||
return [msg]
|
||||
|
||||
# Default: user/assistant/system messages with content
|
||||
content = chat_msg.get("content", "")
|
||||
if isinstance(content, str):
|
||||
contents = [TextContent(text=content)]
|
||||
else:
|
||||
# TODO: Support refusal.
|
||||
contents = [TextContent(text=c.get("text", "")) for c in content]
|
||||
msg = Message.from_role_and_contents(role, contents)
|
||||
return [msg]
|
||||
|
||||
|
||||
def construct_harmony_previous_input_messages(
|
||||
request: ResponsesRequest,
|
||||
) -> list[OpenAIHarmonyMessage]:
|
||||
messages: list[OpenAIHarmonyMessage] = []
|
||||
if request.previous_input_messages:
|
||||
for message in request.previous_input_messages:
|
||||
# Handle both OpenAIHarmonyMessage objects and dictionary inputs
|
||||
if isinstance(message, OpenAIHarmonyMessage):
|
||||
message_role = message.author.role
|
||||
# To match OpenAI, instructions, reasoning and tools are
|
||||
# always taken from the most recent Responses API request
|
||||
# not carried over from previous requests
|
||||
if (
|
||||
message_role == OpenAIHarmonyRole.SYSTEM
|
||||
or message_role == OpenAIHarmonyRole.DEVELOPER
|
||||
):
|
||||
continue
|
||||
messages.append(message)
|
||||
else:
|
||||
harmony_messages = parse_input_to_harmony_message(message)
|
||||
for harmony_msg in harmony_messages:
|
||||
message_role = harmony_msg.author.role
|
||||
# To match OpenAI, instructions, reasoning and tools are
|
||||
# always taken from the most recent Responses API request
|
||||
# not carried over from previous requests
|
||||
if (
|
||||
message_role == OpenAIHarmonyRole.SYSTEM
|
||||
or message_role == OpenAIHarmonyRole.DEVELOPER
|
||||
):
|
||||
continue
|
||||
messages.append(harmony_msg)
|
||||
return messages
|
||||
|
||||
|
||||
def render_for_completion(messages: list[Message]) -> list[int]:
|
||||
conversation = Conversation.from_messages(messages)
|
||||
token_ids = get_encoding().render_conversation_for_completion(
|
||||
conversation, Role.ASSISTANT
|
||||
)
|
||||
return token_ids
|
||||
|
||||
|
||||
def parse_output_message(message: Message) -> list[ResponseOutputItem]:
|
||||
"""
|
||||
Parse a Harmony message into a list of output response items.
|
||||
"""
|
||||
if message.author.role != "assistant":
|
||||
# This is a message from a tool to the assistant (e.g., search result).
|
||||
# Don't include it in the final output for now. This aligns with
|
||||
# OpenAI's behavior on models like o4-mini.
|
||||
return []
|
||||
|
||||
output_items: list[ResponseOutputItem] = []
|
||||
recipient = message.recipient
|
||||
if recipient is not None and recipient.startswith("browser."):
|
||||
if len(message.content) != 1:
|
||||
raise ValueError("Invalid number of contents in browser message")
|
||||
content = message.content[0]
|
||||
# We do not need to check the VLLM_TOOL_JSON_ERROR_AUTOMATIC_RETRY
|
||||
# env variable since if it is not set, we are certain the json is valid
|
||||
# The use of Actions for web search will be removed entirely in
|
||||
# the future, so this is only necessary temporarily
|
||||
try:
|
||||
browser_call = json.loads(content.text)
|
||||
except json.JSONDecodeError:
|
||||
# If the content is not valid JSON, then it was
|
||||
# caught and retried by vLLM, which means we
|
||||
# need to make note of that so the user is aware
|
||||
json_retry_output_message = (
|
||||
f"Invalid JSON args, caught and retried: {content.text}"
|
||||
)
|
||||
browser_call = {
|
||||
"query": json_retry_output_message,
|
||||
"url": json_retry_output_message,
|
||||
"pattern": json_retry_output_message,
|
||||
}
|
||||
# TODO: translate to url properly!
|
||||
if recipient == "browser.search":
|
||||
action = ActionSearch(
|
||||
query=f"cursor:{browser_call.get('query', '')}", type="search"
|
||||
)
|
||||
elif recipient == "browser.open":
|
||||
action = ActionOpenPage(
|
||||
url=f"cursor:{browser_call.get('url', '')}", type="open_page"
|
||||
)
|
||||
elif recipient == "browser.find":
|
||||
action = ActionFind(
|
||||
pattern=browser_call["pattern"],
|
||||
url=f"cursor:{browser_call.get('url', '')}",
|
||||
type="find",
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown browser action: {recipient}")
|
||||
web_search_item = ResponseFunctionWebSearch(
|
||||
id=f"ws_{random_uuid()}",
|
||||
action=action,
|
||||
status="completed",
|
||||
type="web_search_call",
|
||||
)
|
||||
output_items.append(web_search_item)
|
||||
elif message.channel == "analysis":
|
||||
for content in message.content:
|
||||
reasoning_item = ResponseReasoningItem(
|
||||
id=f"rs_{random_uuid()}",
|
||||
summary=[],
|
||||
type="reasoning",
|
||||
content=[
|
||||
ResponseReasoningTextContent(
|
||||
text=content.text, type="reasoning_text"
|
||||
)
|
||||
],
|
||||
status=None,
|
||||
)
|
||||
output_items.append(reasoning_item)
|
||||
elif message.channel == "commentary":
|
||||
if recipient is not None and recipient.startswith("functions."):
|
||||
function_name = recipient.split(".")[-1]
|
||||
for content in message.content:
|
||||
random_id = random_uuid()
|
||||
response_item = ResponseFunctionToolCall(
|
||||
arguments=content.text,
|
||||
call_id=f"call_{random_id}",
|
||||
type="function_call",
|
||||
name=function_name,
|
||||
id=f"fc_{random_id}",
|
||||
)
|
||||
output_items.append(response_item)
|
||||
elif recipient is not None and (
|
||||
recipient.startswith("python")
|
||||
or recipient.startswith("browser")
|
||||
or recipient.startswith("container")
|
||||
):
|
||||
for content in message.content:
|
||||
reasoning_item = ResponseReasoningItem(
|
||||
id=f"rs_{random_uuid()}",
|
||||
summary=[],
|
||||
type="reasoning",
|
||||
content=[
|
||||
ResponseReasoningTextContent(
|
||||
text=content.text, type="reasoning_text"
|
||||
)
|
||||
],
|
||||
status=None,
|
||||
)
|
||||
output_items.append(reasoning_item)
|
||||
else:
|
||||
raise ValueError(f"Unknown recipient: {recipient}")
|
||||
elif message.channel == "final":
|
||||
contents = []
|
||||
for content in message.content:
|
||||
output_text = ResponseOutputText(
|
||||
text=content.text,
|
||||
annotations=[], # TODO
|
||||
type="output_text",
|
||||
logprobs=None, # TODO
|
||||
)
|
||||
contents.append(output_text)
|
||||
text_item = ResponseOutputMessage(
|
||||
id=f"msg_{random_uuid()}",
|
||||
content=contents,
|
||||
role=message.author.role,
|
||||
status="completed",
|
||||
type="message",
|
||||
)
|
||||
output_items.append(text_item)
|
||||
else:
|
||||
raise ValueError(f"Unknown channel: {message.channel}")
|
||||
return output_items
|
||||
|
||||
|
||||
def parse_remaining_state(parser: StreamableParser) -> list[ResponseOutputItem]:
|
||||
if not parser.current_content:
|
||||
return []
|
||||
if parser.current_role != Role.ASSISTANT:
|
||||
return []
|
||||
current_recipient = parser.current_recipient
|
||||
if current_recipient is not None and current_recipient.startswith("browser."):
|
||||
return []
|
||||
|
||||
if parser.current_channel == "analysis":
|
||||
reasoning_item = ResponseReasoningItem(
|
||||
id=f"rs_{random_uuid()}",
|
||||
summary=[],
|
||||
type="reasoning",
|
||||
content=[
|
||||
ResponseReasoningTextContent(
|
||||
text=parser.current_content, type="reasoning_text"
|
||||
)
|
||||
],
|
||||
status=None,
|
||||
)
|
||||
return [reasoning_item]
|
||||
elif parser.current_channel == "final":
|
||||
output_text = ResponseOutputText(
|
||||
text=parser.current_content,
|
||||
annotations=[], # TODO
|
||||
type="output_text",
|
||||
logprobs=None, # TODO
|
||||
)
|
||||
text_item = ResponseOutputMessage(
|
||||
id=f"msg_{random_uuid()}",
|
||||
content=[output_text],
|
||||
role="assistant",
|
||||
# if the parser still has messages (ie if the generator got cut
|
||||
# abruptly), this should be incomplete
|
||||
status="incomplete",
|
||||
type="message",
|
||||
)
|
||||
return [text_item]
|
||||
return []
|
||||
|
||||
|
||||
def get_stop_tokens_for_assistant_actions() -> list[int]:
|
||||
return get_encoding().stop_tokens_for_assistant_actions()
|
||||
|
||||
|
||||
def get_streamable_parser_for_assistant() -> StreamableParser:
|
||||
return StreamableParser(get_encoding(), role=Role.ASSISTANT)
|
||||
|
||||
|
||||
def parse_output_into_messages(token_ids: Iterable[int]) -> StreamableParser:
|
||||
parser = get_streamable_parser_for_assistant()
|
||||
for token_id in token_ids:
|
||||
parser.process(token_id)
|
||||
return parser
|
||||
|
||||
|
||||
def parse_chat_output(
|
||||
token_ids: Sequence[int],
|
||||
) -> tuple[str | None, str | None, bool]:
|
||||
parser = parse_output_into_messages(token_ids)
|
||||
output_msgs = parser.messages
|
||||
is_tool_call = False # TODO: update this when tool call is supported
|
||||
if len(output_msgs) == 0:
|
||||
# The generation has stopped during reasoning.
|
||||
reasoning = parser.current_content
|
||||
final_content = None
|
||||
elif len(output_msgs) == 1:
|
||||
# The generation has stopped during final message.
|
||||
reasoning = output_msgs[0].content[0].text
|
||||
final_content = parser.current_content
|
||||
else:
|
||||
reasoning_msg = output_msgs[:-1]
|
||||
final_msg = output_msgs[-1]
|
||||
reasoning = "\n".join([msg.content[0].text for msg in reasoning_msg])
|
||||
final_content = final_msg.content[0].text
|
||||
return reasoning, final_content, is_tool_call
|
||||
175
vllm_old/entrypoints/launcher.py
Normal file
175
vllm_old/entrypoints/launcher.py
Normal file
@@ -0,0 +1,175 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import asyncio
|
||||
import signal
|
||||
import socket
|
||||
from http import HTTPStatus
|
||||
from typing import Any
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, Request, Response
|
||||
|
||||
from vllm import envs
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.constants import (
|
||||
H11_MAX_HEADER_COUNT_DEFAULT,
|
||||
H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT,
|
||||
)
|
||||
from vllm.entrypoints.ssl import SSLCertRefresher
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.network_utils import find_process_using_port
|
||||
from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
async def serve_http(
|
||||
app: FastAPI,
|
||||
sock: socket.socket | None,
|
||||
enable_ssl_refresh: bool = False,
|
||||
**uvicorn_kwargs: Any,
|
||||
):
|
||||
"""
|
||||
Start a FastAPI app using Uvicorn, with support for custom Uvicorn config
|
||||
options. Supports http header limits via h11_max_incomplete_event_size and
|
||||
h11_max_header_count.
|
||||
"""
|
||||
logger.info("Available routes are:")
|
||||
for route in app.routes:
|
||||
methods = getattr(route, "methods", None)
|
||||
path = getattr(route, "path", None)
|
||||
|
||||
if methods is None or path is None:
|
||||
continue
|
||||
|
||||
logger.info("Route: %s, Methods: %s", path, ", ".join(methods))
|
||||
|
||||
# Extract header limit options if present
|
||||
h11_max_incomplete_event_size = uvicorn_kwargs.pop(
|
||||
"h11_max_incomplete_event_size", None
|
||||
)
|
||||
h11_max_header_count = uvicorn_kwargs.pop("h11_max_header_count", None)
|
||||
|
||||
# Set safe defaults if not provided
|
||||
if h11_max_incomplete_event_size is None:
|
||||
h11_max_incomplete_event_size = H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT
|
||||
if h11_max_header_count is None:
|
||||
h11_max_header_count = H11_MAX_HEADER_COUNT_DEFAULT
|
||||
|
||||
config = uvicorn.Config(app, **uvicorn_kwargs)
|
||||
# Set header limits
|
||||
config.h11_max_incomplete_event_size = h11_max_incomplete_event_size
|
||||
config.h11_max_header_count = h11_max_header_count
|
||||
config.load()
|
||||
server = uvicorn.Server(config)
|
||||
_add_shutdown_handlers(app, server)
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
watchdog_task = loop.create_task(watchdog_loop(server, app.state.engine_client))
|
||||
server_task = loop.create_task(server.serve(sockets=[sock] if sock else None))
|
||||
|
||||
ssl_cert_refresher = (
|
||||
None
|
||||
if not enable_ssl_refresh
|
||||
else SSLCertRefresher(
|
||||
ssl_context=config.ssl,
|
||||
key_path=config.ssl_keyfile,
|
||||
cert_path=config.ssl_certfile,
|
||||
ca_path=config.ssl_ca_certs,
|
||||
)
|
||||
)
|
||||
|
||||
def signal_handler() -> None:
|
||||
# prevents the uvicorn signal handler to exit early
|
||||
server_task.cancel()
|
||||
watchdog_task.cancel()
|
||||
if ssl_cert_refresher:
|
||||
ssl_cert_refresher.stop()
|
||||
|
||||
async def dummy_shutdown() -> None:
|
||||
pass
|
||||
|
||||
loop.add_signal_handler(signal.SIGINT, signal_handler)
|
||||
loop.add_signal_handler(signal.SIGTERM, signal_handler)
|
||||
|
||||
try:
|
||||
await server_task
|
||||
return dummy_shutdown()
|
||||
except asyncio.CancelledError:
|
||||
port = uvicorn_kwargs["port"]
|
||||
process = find_process_using_port(port)
|
||||
if process is not None:
|
||||
logger.warning(
|
||||
"port %s is used by process %s launched with command:\n%s",
|
||||
port,
|
||||
process,
|
||||
" ".join(process.cmdline()),
|
||||
)
|
||||
logger.info("Shutting down FastAPI HTTP server.")
|
||||
return server.shutdown()
|
||||
finally:
|
||||
watchdog_task.cancel()
|
||||
|
||||
|
||||
async def watchdog_loop(server: uvicorn.Server, engine: EngineClient):
|
||||
"""
|
||||
# Watchdog task that runs in the background, checking
|
||||
# for error state in the engine. Needed to trigger shutdown
|
||||
# if an exception arises is StreamingResponse() generator.
|
||||
"""
|
||||
VLLM_WATCHDOG_TIME_S = 5.0
|
||||
while True:
|
||||
await asyncio.sleep(VLLM_WATCHDOG_TIME_S)
|
||||
terminate_if_errored(server, engine)
|
||||
|
||||
|
||||
def terminate_if_errored(server: uvicorn.Server, engine: EngineClient):
|
||||
"""
|
||||
See discussions here on shutting down a uvicorn server
|
||||
https://github.com/encode/uvicorn/discussions/1103
|
||||
In this case we cannot await the server shutdown here
|
||||
because handler must first return to close the connection
|
||||
for this request.
|
||||
"""
|
||||
engine_errored = engine.errored and not engine.is_running
|
||||
if not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH and engine_errored:
|
||||
server.should_exit = True
|
||||
|
||||
|
||||
def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server) -> None:
|
||||
"""
|
||||
VLLM V1 AsyncLLM catches exceptions and returns
|
||||
only two types: EngineGenerateError and EngineDeadError.
|
||||
|
||||
EngineGenerateError is raised by the per request generate()
|
||||
method. This error could be request specific (and therefore
|
||||
recoverable - e.g. if there is an error in input processing).
|
||||
|
||||
EngineDeadError is raised by the background output_handler
|
||||
method. This error is global and therefore not recoverable.
|
||||
|
||||
We register these @app.exception_handlers to return nice
|
||||
responses to the end user if they occur and shut down if needed.
|
||||
See https://fastapi.tiangolo.com/tutorial/handling-errors/
|
||||
for more details on how exception handlers work.
|
||||
|
||||
If an exception is encountered in a StreamingResponse
|
||||
generator, the exception is not raised, since we already sent
|
||||
a 200 status. Rather, we send an error message as the next chunk.
|
||||
Since the exception is not raised, this means that the server
|
||||
will not automatically shut down. Instead, we use the watchdog
|
||||
background task for check for errored state.
|
||||
"""
|
||||
|
||||
@app.exception_handler(RuntimeError)
|
||||
@app.exception_handler(EngineDeadError)
|
||||
@app.exception_handler(EngineGenerateError)
|
||||
async def runtime_exception_handler(request: Request, __):
|
||||
terminate_if_errored(
|
||||
server=server,
|
||||
engine=request.app.state.engine_client,
|
||||
)
|
||||
|
||||
return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
|
||||
1768
vllm_old/entrypoints/llm.py
Normal file
1768
vllm_old/entrypoints/llm.py
Normal file
File diff suppressed because it is too large
Load Diff
84
vllm_old/entrypoints/logger.py
Normal file
84
vllm_old/entrypoints/logger.py
Normal file
@@ -0,0 +1,84 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class RequestLogger:
|
||||
def __init__(self, *, max_log_len: int | None) -> None:
|
||||
self.max_log_len = max_log_len
|
||||
|
||||
def log_inputs(
|
||||
self,
|
||||
request_id: str,
|
||||
prompt: str | None,
|
||||
prompt_token_ids: list[int] | None,
|
||||
prompt_embeds: torch.Tensor | None,
|
||||
params: SamplingParams | PoolingParams | BeamSearchParams | None,
|
||||
lora_request: LoRARequest | None,
|
||||
) -> None:
|
||||
max_log_len = self.max_log_len
|
||||
if max_log_len is not None:
|
||||
if prompt is not None:
|
||||
prompt = prompt[:max_log_len]
|
||||
|
||||
if prompt_token_ids is not None:
|
||||
prompt_token_ids = prompt_token_ids[:max_log_len]
|
||||
|
||||
logger.debug(
|
||||
"Request %s details: prompt: %r, "
|
||||
"prompt_token_ids: %s, "
|
||||
"prompt_embeds shape: %s.",
|
||||
request_id,
|
||||
prompt,
|
||||
prompt_token_ids,
|
||||
prompt_embeds.shape if prompt_embeds is not None else None,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Received request %s: params: %s, lora_request: %s.",
|
||||
request_id,
|
||||
params,
|
||||
lora_request,
|
||||
)
|
||||
|
||||
def log_outputs(
|
||||
self,
|
||||
request_id: str,
|
||||
outputs: str,
|
||||
output_token_ids: Sequence[int] | None,
|
||||
finish_reason: str | None = None,
|
||||
is_streaming: bool = False,
|
||||
delta: bool = False,
|
||||
) -> None:
|
||||
max_log_len = self.max_log_len
|
||||
if max_log_len is not None:
|
||||
if outputs is not None:
|
||||
outputs = outputs[:max_log_len]
|
||||
|
||||
if output_token_ids is not None:
|
||||
# Convert to list and apply truncation
|
||||
output_token_ids = list(output_token_ids)[:max_log_len]
|
||||
|
||||
stream_info = ""
|
||||
if is_streaming:
|
||||
stream_info = " (streaming delta)" if delta else " (streaming complete)"
|
||||
|
||||
logger.info(
|
||||
"Generated response %s%s: output: %r, "
|
||||
"output_token_ids: %s, finish_reason: %s",
|
||||
request_id,
|
||||
stream_info,
|
||||
outputs,
|
||||
output_token_ids,
|
||||
finish_reason,
|
||||
)
|
||||
0
vllm_old/entrypoints/openai/__init__.py
Normal file
0
vllm_old/entrypoints/openai/__init__.py
Normal file
BIN
vllm_old/entrypoints/openai/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
vllm_old/entrypoints/openai/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
vllm_old/entrypoints/openai/__pycache__/cli_args.cpython-312.pyc
Normal file
BIN
vllm_old/entrypoints/openai/__pycache__/cli_args.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
vllm_old/entrypoints/openai/__pycache__/protocol.cpython-312.pyc
Normal file
BIN
vllm_old/entrypoints/openai/__pycache__/protocol.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
2096
vllm_old/entrypoints/openai/api_server.py
Normal file
2096
vllm_old/entrypoints/openai/api_server.py
Normal file
File diff suppressed because it is too large
Load Diff
302
vllm_old/entrypoints/openai/cli_args.py
Normal file
302
vllm_old/entrypoints/openai/cli_args.py
Normal file
@@ -0,0 +1,302 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
This file contains the command line arguments for the vLLM's
|
||||
OpenAI-compatible server. It is kept in a separate file for documentation
|
||||
purposes.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import ssl
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import field
|
||||
from typing import Literal
|
||||
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import config
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs, optional_type
|
||||
from vllm.entrypoints.chat_utils import (
|
||||
ChatTemplateContentFormatOption,
|
||||
validate_chat_template,
|
||||
)
|
||||
from vllm.entrypoints.constants import (
|
||||
H11_MAX_HEADER_COUNT_DEFAULT,
|
||||
H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT,
|
||||
)
|
||||
from vllm.entrypoints.openai.serving_models import LoRAModulePath
|
||||
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class LoRAParserAction(argparse.Action):
|
||||
def __call__(
|
||||
self,
|
||||
parser: argparse.ArgumentParser,
|
||||
namespace: argparse.Namespace,
|
||||
values: str | Sequence[str] | None,
|
||||
option_string: str | None = None,
|
||||
):
|
||||
if values is None:
|
||||
values = []
|
||||
if isinstance(values, str):
|
||||
raise TypeError("Expected values to be a list")
|
||||
|
||||
lora_list: list[LoRAModulePath] = []
|
||||
for item in values:
|
||||
if item in [None, ""]: # Skip if item is None or empty string
|
||||
continue
|
||||
if "=" in item and "," not in item: # Old format: name=path
|
||||
name, path = item.split("=")
|
||||
lora_list.append(LoRAModulePath(name, path))
|
||||
else: # Assume JSON format
|
||||
try:
|
||||
lora_dict = json.loads(item)
|
||||
lora = LoRAModulePath(**lora_dict)
|
||||
lora_list.append(lora)
|
||||
except json.JSONDecodeError:
|
||||
parser.error(f"Invalid JSON format for --lora-modules: {item}")
|
||||
except TypeError as e:
|
||||
parser.error(
|
||||
f"Invalid fields for --lora-modules: {item} - {str(e)}"
|
||||
)
|
||||
setattr(namespace, self.dest, lora_list)
|
||||
|
||||
|
||||
@config
|
||||
@dataclass
|
||||
class FrontendArgs:
|
||||
"""Arguments for the OpenAI-compatible frontend server."""
|
||||
|
||||
host: str | None = None
|
||||
"""Host name."""
|
||||
port: int = 8000
|
||||
"""Port number."""
|
||||
uds: str | None = None
|
||||
"""Unix domain socket path. If set, host and port arguments are ignored."""
|
||||
uvicorn_log_level: Literal[
|
||||
"debug", "info", "warning", "error", "critical", "trace"
|
||||
] = "info"
|
||||
"""Log level for uvicorn."""
|
||||
disable_uvicorn_access_log: bool = False
|
||||
"""Disable uvicorn access log."""
|
||||
allow_credentials: bool = False
|
||||
"""Allow credentials."""
|
||||
allowed_origins: list[str] = field(default_factory=lambda: ["*"])
|
||||
"""Allowed origins."""
|
||||
allowed_methods: list[str] = field(default_factory=lambda: ["*"])
|
||||
"""Allowed methods."""
|
||||
allowed_headers: list[str] = field(default_factory=lambda: ["*"])
|
||||
"""Allowed headers."""
|
||||
api_key: list[str] | None = None
|
||||
"""If provided, the server will require one of these keys to be presented in
|
||||
the header."""
|
||||
lora_modules: list[LoRAModulePath] | None = None
|
||||
"""LoRA modules configurations in either 'name=path' format or JSON format
|
||||
or JSON list format. Example (old format): `'name=path'` Example (new
|
||||
format): `{\"name\": \"name\", \"path\": \"lora_path\",
|
||||
\"base_model_name\": \"id\"}`"""
|
||||
chat_template: str | None = None
|
||||
"""The file path to the chat template, or the template in single-line form
|
||||
for the specified model."""
|
||||
chat_template_content_format: ChatTemplateContentFormatOption = "auto"
|
||||
"""The format to render message content within a chat template.
|
||||
|
||||
* "string" will render the content as a string. Example: `"Hello World"`
|
||||
* "openai" will render the content as a list of dictionaries, similar to
|
||||
OpenAI schema. Example: `[{"type": "text", "text": "Hello world!"}]`"""
|
||||
trust_request_chat_template: bool = False
|
||||
"""Whether to trust the chat template provided in the request. If False,
|
||||
the server will always use the chat template specified by `--chat-template`
|
||||
or the ones from tokenizer."""
|
||||
response_role: str = "assistant"
|
||||
"""The role name to return if `request.add_generation_prompt=true`."""
|
||||
ssl_keyfile: str | None = None
|
||||
"""The file path to the SSL key file."""
|
||||
ssl_certfile: str | None = None
|
||||
"""The file path to the SSL cert file."""
|
||||
ssl_ca_certs: str | None = None
|
||||
"""The CA certificates file."""
|
||||
enable_ssl_refresh: bool = False
|
||||
"""Refresh SSL Context when SSL certificate files change"""
|
||||
ssl_cert_reqs: int = int(ssl.CERT_NONE)
|
||||
"""Whether client certificate is required (see stdlib ssl module's)."""
|
||||
root_path: str | None = None
|
||||
"""FastAPI root_path when app is behind a path based routing proxy."""
|
||||
middleware: list[str] = field(default_factory=lambda: [])
|
||||
"""Additional ASGI middleware to apply to the app. We accept multiple
|
||||
--middleware arguments. The value should be an import path. If a function
|
||||
is provided, vLLM will add it to the server using
|
||||
`@app.middleware('http')`. If a class is provided, vLLM will
|
||||
add it to the server using `app.add_middleware()`."""
|
||||
return_tokens_as_token_ids: bool = False
|
||||
"""When `--max-logprobs` is specified, represents single tokens as
|
||||
strings of the form 'token_id:{token_id}' so that tokens that are not
|
||||
JSON-encodable can be identified."""
|
||||
disable_frontend_multiprocessing: bool = False
|
||||
"""If specified, will run the OpenAI frontend server in the same process as
|
||||
the model serving engine."""
|
||||
enable_request_id_headers: bool = False
|
||||
"""If specified, API server will add X-Request-Id header to responses."""
|
||||
enable_auto_tool_choice: bool = False
|
||||
"""Enable auto tool choice for supported models. Use `--tool-call-parser`
|
||||
to specify which parser to use."""
|
||||
exclude_tools_when_tool_choice_none: bool = False
|
||||
"""If specified, exclude tool definitions in prompts when
|
||||
tool_choice='none'."""
|
||||
tool_call_parser: str | None = None
|
||||
"""Select the tool call parser depending on the model that you're using.
|
||||
This is used to parse the model-generated tool call into OpenAI API format.
|
||||
Required for `--enable-auto-tool-choice`. You can choose any option from
|
||||
the built-in parsers or register a plugin via `--tool-parser-plugin`."""
|
||||
tool_parser_plugin: str = ""
|
||||
"""Special the tool parser plugin write to parse the model-generated tool
|
||||
into OpenAI API format, the name register in this plugin can be used in
|
||||
`--tool-call-parser`."""
|
||||
tool_server: str | None = None
|
||||
"""Comma-separated list of host:port pairs (IPv4, IPv6, or hostname).
|
||||
Examples: 127.0.0.1:8000, [::1]:8000, localhost:1234. Or `demo` for demo
|
||||
purpose."""
|
||||
log_config_file: str | None = envs.VLLM_LOGGING_CONFIG_PATH
|
||||
"""Path to logging config JSON file for both vllm and uvicorn"""
|
||||
max_log_len: int | None = None
|
||||
"""Max number of prompt characters or prompt ID numbers being printed in
|
||||
log. The default of None means unlimited."""
|
||||
disable_fastapi_docs: bool = False
|
||||
"""Disable FastAPI's OpenAPI schema, Swagger UI, and ReDoc endpoint."""
|
||||
enable_prompt_tokens_details: bool = False
|
||||
"""If set to True, enable prompt_tokens_details in usage."""
|
||||
enable_server_load_tracking: bool = False
|
||||
"""If set to True, enable tracking server_load_metrics in the app state."""
|
||||
enable_force_include_usage: bool = False
|
||||
"""If set to True, including usage on every request."""
|
||||
enable_tokenizer_info_endpoint: bool = False
|
||||
"""Enable the /get_tokenizer_info endpoint. May expose chat
|
||||
templates and other tokenizer configuration."""
|
||||
enable_log_outputs: bool = False
|
||||
"""If True, log model outputs (generations).
|
||||
Requires --enable-log-requests."""
|
||||
h11_max_incomplete_event_size: int = H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT
|
||||
"""Maximum size (bytes) of an incomplete HTTP event (header or body) for
|
||||
h11 parser. Helps mitigate header abuse. Default: 4194304 (4 MB)."""
|
||||
h11_max_header_count: int = H11_MAX_HEADER_COUNT_DEFAULT
|
||||
"""Maximum number of HTTP headers allowed in a request for h11 parser.
|
||||
Helps mitigate header abuse. Default: 256."""
|
||||
log_error_stack: bool = envs.VLLM_SERVER_DEV_MODE
|
||||
"""If set to True, log the stack trace of error responses"""
|
||||
tokens_only: bool = False
|
||||
"""
|
||||
If set to True, only enable the Tokens In<>Out endpoint.
|
||||
This is intended for use in a Disaggregated Everything setup.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
|
||||
from vllm.engine.arg_utils import get_kwargs
|
||||
|
||||
frontend_kwargs = get_kwargs(FrontendArgs)
|
||||
|
||||
# Special case: allowed_origins, allowed_methods, allowed_headers all
|
||||
# need json.loads type
|
||||
# Should also remove nargs
|
||||
frontend_kwargs["allowed_origins"]["type"] = json.loads
|
||||
frontend_kwargs["allowed_methods"]["type"] = json.loads
|
||||
frontend_kwargs["allowed_headers"]["type"] = json.loads
|
||||
del frontend_kwargs["allowed_origins"]["nargs"]
|
||||
del frontend_kwargs["allowed_methods"]["nargs"]
|
||||
del frontend_kwargs["allowed_headers"]["nargs"]
|
||||
|
||||
# Special case: LoRA modules need custom parser action and
|
||||
# optional_type(str)
|
||||
frontend_kwargs["lora_modules"]["type"] = optional_type(str)
|
||||
frontend_kwargs["lora_modules"]["action"] = LoRAParserAction
|
||||
|
||||
# Special case: Middleware needs to append action
|
||||
frontend_kwargs["middleware"]["action"] = "append"
|
||||
frontend_kwargs["middleware"]["type"] = str
|
||||
if "nargs" in frontend_kwargs["middleware"]:
|
||||
del frontend_kwargs["middleware"]["nargs"]
|
||||
frontend_kwargs["middleware"]["default"] = []
|
||||
|
||||
# Special case: Tool call parser shows built-in options.
|
||||
valid_tool_parsers = list(ToolParserManager.list_registered())
|
||||
parsers_str = ",".join(valid_tool_parsers)
|
||||
frontend_kwargs["tool_call_parser"]["metavar"] = (
|
||||
f"{{{parsers_str}}} or name registered in --tool-parser-plugin"
|
||||
)
|
||||
|
||||
frontend_group = parser.add_argument_group(
|
||||
title="Frontend",
|
||||
description=FrontendArgs.__doc__,
|
||||
)
|
||||
|
||||
for key, value in frontend_kwargs.items():
|
||||
frontend_group.add_argument(f"--{key.replace('_', '-')}", **value)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
|
||||
"""Create the CLI argument parser used by the OpenAI API server.
|
||||
|
||||
We rely on the helper methods of `FrontendArgs` and `AsyncEngineArgs` to
|
||||
register all arguments instead of manually enumerating them here. This
|
||||
avoids code duplication and keeps the argument definitions in one place.
|
||||
"""
|
||||
parser.add_argument(
|
||||
"model_tag",
|
||||
type=str,
|
||||
nargs="?",
|
||||
help="The model tag to serve (optional if specified in config)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--headless",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Run in headless mode. See multi-node data parallel "
|
||||
"documentation for more details.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--api-server-count",
|
||||
"-asc",
|
||||
type=int,
|
||||
default=1,
|
||||
help="How many API server processes to run.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--config",
|
||||
help="Read CLI options from a config file. "
|
||||
"Must be a YAML with the following options: "
|
||||
"https://docs.vllm.ai/en/latest/configuration/serve_args.html",
|
||||
)
|
||||
parser = FrontendArgs.add_cli_args(parser)
|
||||
parser = AsyncEngineArgs.add_cli_args(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def validate_parsed_serve_args(args: argparse.Namespace):
|
||||
"""Quick checks for model serve args that raise prior to loading."""
|
||||
if hasattr(args, "subparser") and args.subparser != "serve":
|
||||
return
|
||||
|
||||
# Ensure that the chat template is valid; raises if it likely isn't
|
||||
validate_chat_template(args.chat_template)
|
||||
|
||||
# Enable auto tool needs a tool call parser to be valid
|
||||
if args.enable_auto_tool_choice and not args.tool_call_parser:
|
||||
raise TypeError("Error: --enable-auto-tool-choice requires --tool-call-parser")
|
||||
if args.enable_log_outputs and not args.enable_log_requests:
|
||||
raise TypeError("Error: --enable-log-outputs requires --enable-log-requests")
|
||||
|
||||
|
||||
def create_parser_for_docs() -> FlexibleArgumentParser:
|
||||
parser_for_docs = FlexibleArgumentParser(
|
||||
prog="-m vllm.entrypoints.openai.api_server"
|
||||
)
|
||||
return make_arg_parser(parser_for_docs)
|
||||
120
vllm_old/entrypoints/openai/orca_metrics.py
Normal file
120
vllm_old/entrypoints/openai/orca_metrics.py
Normal file
@@ -0,0 +1,120 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Utility functions that create ORCA endpoint load report response headers.
|
||||
"""
|
||||
|
||||
import json
|
||||
from collections.abc import Mapping
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.metrics.reader import Gauge, get_metrics_snapshot
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def create_orca_header(
|
||||
metrics_format: str, named_metrics: list[tuple[str, float]]
|
||||
) -> Mapping[str, str] | None:
|
||||
"""
|
||||
Creates ORCA headers named 'endpoint-load-metrics' in the specified format
|
||||
and adds custom metrics to named_metrics.
|
||||
ORCA headers format description: https://docs.google.com/document/d/1C1ybMmDKJIVlrbOLbywhu9iRYo4rilR-cT50OTtOFTs/edit?tab=t.0
|
||||
ORCA proto https://github.com/cncf/xds/blob/main/xds/data/orca/v3/orca_load_report.proto
|
||||
|
||||
Parameters:
|
||||
- metrics_format (str): The format of the header ('TEXT', 'JSON').
|
||||
- named_metrics (List[Tuple[str, float]]): List of tuples with metric names
|
||||
and their corresponding double values.
|
||||
|
||||
Returns:
|
||||
- Optional[Mapping[str,str]]: A dictionary with header key as
|
||||
'endpoint-load-metrics' and values as the ORCA header strings with
|
||||
format prefix and data in with named_metrics in.
|
||||
"""
|
||||
|
||||
if metrics_format.lower() not in ["text", "json"]:
|
||||
logger.warning(
|
||||
"Warning: `%s` format is not supported in the ORCA response header",
|
||||
format,
|
||||
)
|
||||
return None
|
||||
|
||||
header = {}
|
||||
orca_report = {
|
||||
"named_metrics": {
|
||||
metric_name: value
|
||||
for metric_name, value in named_metrics
|
||||
if isinstance(metric_name, str) and isinstance(value, float)
|
||||
}
|
||||
}
|
||||
# output example:
|
||||
# endpoint-load-metrics: TEXT named_metrics.kv_cache_utilization=0.4
|
||||
if metrics_format.lower() == "text":
|
||||
native_http_header = ", ".join(
|
||||
[
|
||||
f"named_metrics.{metric_name}={value}"
|
||||
for metric_name, value in named_metrics
|
||||
if isinstance(metric_name, str) and isinstance(value, float)
|
||||
]
|
||||
)
|
||||
header["endpoint-load-metrics"] = f"TEXT {native_http_header}"
|
||||
|
||||
# output example:
|
||||
# endpoint-load-metrics: JSON “named_metrics”: {“custom-metric-util”: 0.4}
|
||||
elif metrics_format.lower() == "json":
|
||||
header["endpoint-load-metrics"] = f"JSON {json.dumps(orca_report)}"
|
||||
|
||||
logger.info("Created ORCA header %s", header)
|
||||
|
||||
return header
|
||||
|
||||
|
||||
def get_named_metrics_from_prometheus() -> list[tuple[str, float]]:
|
||||
"""
|
||||
Collects current metrics from Prometheus and returns some of them
|
||||
in the form of the `named_metrics` list for `create_orca_header()`.
|
||||
|
||||
Parameters:
|
||||
- None
|
||||
|
||||
Returns:
|
||||
- list[tuple[str, float]]: List of tuples of metric names and their values.
|
||||
"""
|
||||
named_metrics: list[tuple[str, float]] = []
|
||||
# Map from prometheus metric names to ORCA named metrics.
|
||||
prometheus_to_orca_metrics = {
|
||||
"vllm:kv_cache_usage_perc": "kv_cache_usage_perc",
|
||||
"vllm:num_requests_waiting": "num_requests_waiting",
|
||||
}
|
||||
metrics = get_metrics_snapshot()
|
||||
for metric in metrics:
|
||||
orca_name = prometheus_to_orca_metrics.get(metric.name)
|
||||
# If this metric is mapped into ORCA, then add it to the report.
|
||||
# Note: Only Gauge metrics are currently supported.
|
||||
if orca_name is not None and isinstance(metric, Gauge):
|
||||
named_metrics.append((str(orca_name), float(metric.value)))
|
||||
return named_metrics
|
||||
|
||||
|
||||
def metrics_header(metrics_format: str) -> Mapping[str, str] | None:
|
||||
"""
|
||||
Creates ORCA headers named 'endpoint-load-metrics' in the specified format.
|
||||
Metrics are collected from Prometheus using `get_named_metrics_from_prometheus()`.
|
||||
|
||||
ORCA headers format description: https://docs.google.com/document/d/1C1ybMmDKJIVlrbOLbywhu9iRYo4rilR-cT50OTtOFTs/edit?tab=t.0
|
||||
ORCA proto https://github.com/cncf/xds/blob/main/xds/data/orca/v3/orca_load_report.proto
|
||||
|
||||
Parameters:
|
||||
- metrics_format (str): The format of the header ('TEXT', 'JSON').
|
||||
|
||||
Returns:
|
||||
- Optional[Mapping[str,str]]: A dictionary with header key as
|
||||
'endpoint-load-metrics' and values as the ORCA header strings with
|
||||
format prefix and data in with named_metrics in.
|
||||
"""
|
||||
if not metrics_format:
|
||||
return None
|
||||
# Get named metrics from prometheus.
|
||||
named_metrics = get_named_metrics_from_prometheus()
|
||||
return create_orca_header(metrics_format, named_metrics)
|
||||
3299
vllm_old/entrypoints/openai/protocol.py
Normal file
3299
vllm_old/entrypoints/openai/protocol.py
Normal file
File diff suppressed because it is too large
Load Diff
547
vllm_old/entrypoints/openai/run_batch.py
Normal file
547
vllm_old/entrypoints/openai/run_batch.py
Normal file
@@ -0,0 +1,547 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import asyncio
|
||||
import tempfile
|
||||
from argparse import Namespace
|
||||
from collections.abc import Awaitable, Callable
|
||||
from http import HTTPStatus
|
||||
from io import StringIO
|
||||
|
||||
import aiohttp
|
||||
import torch
|
||||
from prometheus_client import start_http_server
|
||||
from tqdm import tqdm
|
||||
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs, optional_type
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
BatchRequestInput,
|
||||
BatchRequestOutput,
|
||||
BatchResponseData,
|
||||
ChatCompletionResponse,
|
||||
EmbeddingResponse,
|
||||
ErrorResponse,
|
||||
RerankResponse,
|
||||
ScoreResponse,
|
||||
)
|
||||
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
|
||||
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
|
||||
from vllm.entrypoints.openai.serving_score import ServingScores
|
||||
from vllm.logger import init_logger
|
||||
from vllm.reasoning import ReasoningParserManager
|
||||
from vllm.utils import random_uuid
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
from vllm.version import __version__ as VLLM_VERSION
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def make_arg_parser(parser: FlexibleArgumentParser):
|
||||
parser.add_argument(
|
||||
"-i",
|
||||
"--input-file",
|
||||
required=True,
|
||||
type=str,
|
||||
help="The path or url to a single input file. Currently supports local file "
|
||||
"paths, or the http protocol (http or https). If a URL is specified, "
|
||||
"the file should be available via HTTP GET.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-o",
|
||||
"--output-file",
|
||||
required=True,
|
||||
type=str,
|
||||
help="The path or url to a single output file. Currently supports "
|
||||
"local file paths, or web (http or https) urls. If a URL is specified,"
|
||||
" the file should be available via HTTP PUT.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-tmp-dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The directory to store the output file before uploading it "
|
||||
"to the output URL.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--response-role",
|
||||
type=optional_type(str),
|
||||
default="assistant",
|
||||
help="The role name to return if `request.add_generation_prompt=True`.",
|
||||
)
|
||||
|
||||
parser = AsyncEngineArgs.add_cli_args(parser)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-log-len",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Max number of prompt characters or prompt "
|
||||
"ID numbers being printed in log."
|
||||
"\n\nDefault: Unlimited",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--enable-metrics", action="store_true", help="Enable Prometheus metrics"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--url",
|
||||
type=str,
|
||||
default="0.0.0.0",
|
||||
help="URL to the Prometheus metrics server "
|
||||
"(only needed if enable-metrics is set).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=8000,
|
||||
help="Port number for the Prometheus metrics server "
|
||||
"(only needed if enable-metrics is set).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable-prompt-tokens-details",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="If set to True, enable prompt_tokens_details in usage.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable-force-include-usage",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="If set to True, include usage on every request "
|
||||
"(even when stream_options is not specified)",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = FlexibleArgumentParser(description="vLLM OpenAI-Compatible batch runner.")
|
||||
return make_arg_parser(parser).parse_args()
|
||||
|
||||
|
||||
# explicitly use pure text format, with a newline at the end
|
||||
# this makes it impossible to see the animation in the progress bar
|
||||
# but will avoid messing up with ray or multiprocessing, which wraps
|
||||
# each line of output with some prefix.
|
||||
_BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n" # noqa: E501
|
||||
|
||||
|
||||
class BatchProgressTracker:
|
||||
def __init__(self):
|
||||
self._total = 0
|
||||
self._pbar: tqdm | None = None
|
||||
|
||||
def submitted(self):
|
||||
self._total += 1
|
||||
|
||||
def completed(self):
|
||||
if self._pbar:
|
||||
self._pbar.update()
|
||||
|
||||
def pbar(self) -> tqdm:
|
||||
enable_tqdm = (
|
||||
not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
|
||||
)
|
||||
self._pbar = tqdm(
|
||||
total=self._total,
|
||||
unit="req",
|
||||
desc="Running batch",
|
||||
mininterval=5,
|
||||
disable=not enable_tqdm,
|
||||
bar_format=_BAR_FORMAT,
|
||||
)
|
||||
return self._pbar
|
||||
|
||||
|
||||
async def read_file(path_or_url: str) -> str:
|
||||
if path_or_url.startswith("http://") or path_or_url.startswith("https://"):
|
||||
async with aiohttp.ClientSession() as session, session.get(path_or_url) as resp:
|
||||
return await resp.text()
|
||||
else:
|
||||
with open(path_or_url, encoding="utf-8") as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
async def write_local_file(
|
||||
output_path: str, batch_outputs: list[BatchRequestOutput]
|
||||
) -> None:
|
||||
"""
|
||||
Write the responses to a local file.
|
||||
output_path: The path to write the responses to.
|
||||
batch_outputs: The list of batch outputs to write.
|
||||
"""
|
||||
# We should make this async, but as long as run_batch runs as a
|
||||
# standalone program, blocking the event loop won't affect performance.
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
for o in batch_outputs:
|
||||
print(o.model_dump_json(), file=f)
|
||||
|
||||
|
||||
async def upload_data(output_url: str, data_or_file: str, from_file: bool) -> None:
|
||||
"""
|
||||
Upload a local file to a URL.
|
||||
output_url: The URL to upload the file to.
|
||||
data_or_file: Either the data to upload or the path to the file to upload.
|
||||
from_file: If True, data_or_file is the path to the file to upload.
|
||||
"""
|
||||
# Timeout is a common issue when uploading large files.
|
||||
# We retry max_retries times before giving up.
|
||||
max_retries = 5
|
||||
# Number of seconds to wait before retrying.
|
||||
delay = 5
|
||||
|
||||
for attempt in range(1, max_retries + 1):
|
||||
try:
|
||||
# We increase the timeout to 1000 seconds to allow
|
||||
# for large files (default is 300).
|
||||
async with aiohttp.ClientSession(
|
||||
timeout=aiohttp.ClientTimeout(total=1000)
|
||||
) as session:
|
||||
if from_file:
|
||||
with open(data_or_file, "rb") as file:
|
||||
async with session.put(output_url, data=file) as response:
|
||||
if response.status != 200:
|
||||
raise Exception(
|
||||
f"Failed to upload file.\n"
|
||||
f"Status: {response.status}\n"
|
||||
f"Response: {response.text()}"
|
||||
)
|
||||
else:
|
||||
async with session.put(output_url, data=data_or_file) as response:
|
||||
if response.status != 200:
|
||||
raise Exception(
|
||||
f"Failed to upload data.\n"
|
||||
f"Status: {response.status}\n"
|
||||
f"Response: {response.text()}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
if attempt < max_retries:
|
||||
logger.error(
|
||||
"Failed to upload data (attempt %d). Error message: %s.\nRetrying in %d seconds...", # noqa: E501
|
||||
attempt,
|
||||
e,
|
||||
delay,
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
else:
|
||||
raise Exception(
|
||||
f"Failed to upload data (attempt {attempt}). Error message: {str(e)}." # noqa: E501
|
||||
) from e
|
||||
|
||||
|
||||
async def write_file(
|
||||
path_or_url: str, batch_outputs: list[BatchRequestOutput], output_tmp_dir: str
|
||||
) -> None:
|
||||
"""
|
||||
Write batch_outputs to a file or upload to a URL.
|
||||
path_or_url: The path or URL to write batch_outputs to.
|
||||
batch_outputs: The list of batch outputs to write.
|
||||
output_tmp_dir: The directory to store the output file before uploading it
|
||||
to the output URL.
|
||||
"""
|
||||
if path_or_url.startswith("http://") or path_or_url.startswith("https://"):
|
||||
if output_tmp_dir is None:
|
||||
logger.info("Writing outputs to memory buffer")
|
||||
output_buffer = StringIO()
|
||||
for o in batch_outputs:
|
||||
print(o.model_dump_json(), file=output_buffer)
|
||||
output_buffer.seek(0)
|
||||
logger.info("Uploading outputs to %s", path_or_url)
|
||||
await upload_data(
|
||||
path_or_url,
|
||||
output_buffer.read().strip().encode("utf-8"),
|
||||
from_file=False,
|
||||
)
|
||||
else:
|
||||
# Write responses to a temporary file and then upload it to the URL.
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w",
|
||||
encoding="utf-8",
|
||||
dir=output_tmp_dir,
|
||||
prefix="tmp_batch_output_",
|
||||
suffix=".jsonl",
|
||||
) as f:
|
||||
logger.info("Writing outputs to temporary local file %s", f.name)
|
||||
await write_local_file(f.name, batch_outputs)
|
||||
logger.info("Uploading outputs to %s", path_or_url)
|
||||
await upload_data(path_or_url, f.name, from_file=True)
|
||||
else:
|
||||
logger.info("Writing outputs to local file %s", path_or_url)
|
||||
await write_local_file(path_or_url, batch_outputs)
|
||||
|
||||
|
||||
def make_error_request_output(
|
||||
request: BatchRequestInput, error_msg: str
|
||||
) -> BatchRequestOutput:
|
||||
batch_output = BatchRequestOutput(
|
||||
id=f"vllm-{random_uuid()}",
|
||||
custom_id=request.custom_id,
|
||||
response=BatchResponseData(
|
||||
status_code=HTTPStatus.BAD_REQUEST,
|
||||
request_id=f"vllm-batch-{random_uuid()}",
|
||||
),
|
||||
error=error_msg,
|
||||
)
|
||||
return batch_output
|
||||
|
||||
|
||||
async def make_async_error_request_output(
|
||||
request: BatchRequestInput, error_msg: str
|
||||
) -> BatchRequestOutput:
|
||||
return make_error_request_output(request, error_msg)
|
||||
|
||||
|
||||
async def run_request(
|
||||
serving_engine_func: Callable,
|
||||
request: BatchRequestInput,
|
||||
tracker: BatchProgressTracker,
|
||||
) -> BatchRequestOutput:
|
||||
response = await serving_engine_func(request.body)
|
||||
|
||||
if isinstance(
|
||||
response,
|
||||
(ChatCompletionResponse, EmbeddingResponse, ScoreResponse, RerankResponse),
|
||||
):
|
||||
batch_output = BatchRequestOutput(
|
||||
id=f"vllm-{random_uuid()}",
|
||||
custom_id=request.custom_id,
|
||||
response=BatchResponseData(
|
||||
body=response, request_id=f"vllm-batch-{random_uuid()}"
|
||||
),
|
||||
error=None,
|
||||
)
|
||||
elif isinstance(response, ErrorResponse):
|
||||
batch_output = BatchRequestOutput(
|
||||
id=f"vllm-{random_uuid()}",
|
||||
custom_id=request.custom_id,
|
||||
response=BatchResponseData(
|
||||
status_code=response.error.code,
|
||||
request_id=f"vllm-batch-{random_uuid()}",
|
||||
),
|
||||
error=response,
|
||||
)
|
||||
else:
|
||||
batch_output = make_error_request_output(
|
||||
request, error_msg="Request must not be sent in stream mode"
|
||||
)
|
||||
|
||||
tracker.completed()
|
||||
return batch_output
|
||||
|
||||
|
||||
def validate_run_batch_args(args):
|
||||
valid_reasoning_parsers = ReasoningParserManager.list_registered()
|
||||
if (
|
||||
reasoning_parser := args.structured_outputs_config.reasoning_parser
|
||||
) and reasoning_parser not in valid_reasoning_parsers:
|
||||
raise KeyError(
|
||||
f"invalid reasoning parser: {reasoning_parser} "
|
||||
f"(chose from {{ {','.join(valid_reasoning_parsers)} }})"
|
||||
)
|
||||
|
||||
|
||||
async def run_batch(
|
||||
engine_client: EngineClient,
|
||||
args: Namespace,
|
||||
) -> None:
|
||||
if args.served_model_name is not None:
|
||||
served_model_names = args.served_model_name
|
||||
else:
|
||||
served_model_names = [args.model]
|
||||
|
||||
if args.enable_log_requests:
|
||||
request_logger = RequestLogger(max_log_len=args.max_log_len)
|
||||
else:
|
||||
request_logger = None
|
||||
|
||||
base_model_paths = [
|
||||
BaseModelPath(name=name, model_path=args.model) for name in served_model_names
|
||||
]
|
||||
|
||||
model_config = engine_client.model_config
|
||||
supported_tasks = await engine_client.get_supported_tasks()
|
||||
logger.info("Supported tasks: %s", supported_tasks)
|
||||
|
||||
# Create the openai serving objects.
|
||||
openai_serving_models = OpenAIServingModels(
|
||||
engine_client=engine_client,
|
||||
base_model_paths=base_model_paths,
|
||||
lora_modules=None,
|
||||
)
|
||||
|
||||
openai_serving_chat = (
|
||||
OpenAIServingChat(
|
||||
engine_client,
|
||||
openai_serving_models,
|
||||
args.response_role,
|
||||
request_logger=request_logger,
|
||||
chat_template=None,
|
||||
chat_template_content_format="auto",
|
||||
reasoning_parser=args.structured_outputs_config.reasoning_parser,
|
||||
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
|
||||
enable_force_include_usage=args.enable_force_include_usage,
|
||||
)
|
||||
if "generate" in supported_tasks
|
||||
else None
|
||||
)
|
||||
|
||||
openai_serving_embedding = (
|
||||
OpenAIServingEmbedding(
|
||||
engine_client,
|
||||
openai_serving_models,
|
||||
request_logger=request_logger,
|
||||
chat_template=None,
|
||||
chat_template_content_format="auto",
|
||||
)
|
||||
if "embed" in supported_tasks
|
||||
else None
|
||||
)
|
||||
|
||||
enable_serving_reranking = (
|
||||
"classify" in supported_tasks
|
||||
and getattr(model_config.hf_config, "num_labels", 0) == 1
|
||||
)
|
||||
|
||||
openai_serving_scores = (
|
||||
ServingScores(
|
||||
engine_client,
|
||||
openai_serving_models,
|
||||
request_logger=request_logger,
|
||||
)
|
||||
if ("embed" in supported_tasks or enable_serving_reranking)
|
||||
else None
|
||||
)
|
||||
|
||||
tracker = BatchProgressTracker()
|
||||
logger.info("Reading batch from %s...", args.input_file)
|
||||
|
||||
# Submit all requests in the file to the engine "concurrently".
|
||||
response_futures: list[Awaitable[BatchRequestOutput]] = []
|
||||
for request_json in (await read_file(args.input_file)).strip().split("\n"):
|
||||
# Skip empty lines.
|
||||
request_json = request_json.strip()
|
||||
if not request_json:
|
||||
continue
|
||||
|
||||
request = BatchRequestInput.model_validate_json(request_json)
|
||||
|
||||
# Determine the type of request and run it.
|
||||
if request.url == "/v1/chat/completions":
|
||||
chat_handler_fn = (
|
||||
openai_serving_chat.create_chat_completion
|
||||
if openai_serving_chat is not None
|
||||
else None
|
||||
)
|
||||
if chat_handler_fn is None:
|
||||
response_futures.append(
|
||||
make_async_error_request_output(
|
||||
request,
|
||||
error_msg="The model does not support Chat Completions API",
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
response_futures.append(run_request(chat_handler_fn, request, tracker))
|
||||
tracker.submitted()
|
||||
elif request.url == "/v1/embeddings":
|
||||
embed_handler_fn = (
|
||||
openai_serving_embedding.create_embedding
|
||||
if openai_serving_embedding is not None
|
||||
else None
|
||||
)
|
||||
if embed_handler_fn is None:
|
||||
response_futures.append(
|
||||
make_async_error_request_output(
|
||||
request,
|
||||
error_msg="The model does not support Embeddings API",
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
response_futures.append(run_request(embed_handler_fn, request, tracker))
|
||||
tracker.submitted()
|
||||
elif request.url.endswith("/score"):
|
||||
score_handler_fn = (
|
||||
openai_serving_scores.create_score
|
||||
if openai_serving_scores is not None
|
||||
else None
|
||||
)
|
||||
if score_handler_fn is None:
|
||||
response_futures.append(
|
||||
make_async_error_request_output(
|
||||
request,
|
||||
error_msg="The model does not support Scores API",
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
response_futures.append(run_request(score_handler_fn, request, tracker))
|
||||
tracker.submitted()
|
||||
elif request.url.endswith("/rerank"):
|
||||
rerank_handler_fn = (
|
||||
openai_serving_scores.do_rerank
|
||||
if openai_serving_scores is not None
|
||||
else None
|
||||
)
|
||||
if rerank_handler_fn is None:
|
||||
response_futures.append(
|
||||
make_async_error_request_output(
|
||||
request,
|
||||
error_msg="The model does not support Rerank API",
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
response_futures.append(run_request(rerank_handler_fn, request, tracker))
|
||||
tracker.submitted()
|
||||
else:
|
||||
response_futures.append(
|
||||
make_async_error_request_output(
|
||||
request,
|
||||
error_msg=f"URL {request.url} was used. "
|
||||
"Supported endpoints: /v1/chat/completions, /v1/embeddings,"
|
||||
" /score, /rerank ."
|
||||
"See vllm/entrypoints/openai/api_server.py for supported "
|
||||
"score/rerank versions.",
|
||||
)
|
||||
)
|
||||
|
||||
with tracker.pbar():
|
||||
responses = await asyncio.gather(*response_futures)
|
||||
|
||||
await write_file(args.output_file, responses, args.output_tmp_dir)
|
||||
|
||||
|
||||
async def main(args: Namespace):
|
||||
from vllm.entrypoints.openai.api_server import build_async_engine_client
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
|
||||
validate_run_batch_args(args)
|
||||
|
||||
async with build_async_engine_client(
|
||||
args,
|
||||
usage_context=UsageContext.OPENAI_BATCH_RUNNER,
|
||||
disable_frontend_multiprocessing=False,
|
||||
) as engine_client:
|
||||
await run_batch(engine_client, args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
|
||||
logger.info("vLLM batch processing API version %s", VLLM_VERSION)
|
||||
logger.info("args: %s", args)
|
||||
|
||||
# Start the Prometheus metrics server. LLMEngine uses the Prometheus client
|
||||
# to publish metrics at the /metrics endpoint.
|
||||
if args.enable_metrics:
|
||||
logger.info("Prometheus metrics enabled")
|
||||
start_http_server(port=args.port, addr=args.url)
|
||||
else:
|
||||
logger.info("Prometheus metrics disabled")
|
||||
|
||||
asyncio.run(main(args))
|
||||
1772
vllm_old/entrypoints/openai/serving_chat.py
Normal file
1772
vllm_old/entrypoints/openai/serving_chat.py
Normal file
File diff suppressed because it is too large
Load Diff
235
vllm_old/entrypoints/openai/serving_classification.py
Normal file
235
vllm_old/entrypoints/openai/serving_classification.py
Normal file
@@ -0,0 +1,235 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from http import HTTPStatus
|
||||
from typing import cast
|
||||
|
||||
import jinja2
|
||||
import numpy as np
|
||||
from fastapi import Request
|
||||
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
ClassificationChatRequest,
|
||||
ClassificationCompletionRequest,
|
||||
ClassificationData,
|
||||
ClassificationRequest,
|
||||
ClassificationResponse,
|
||||
ErrorResponse,
|
||||
UsageInfo,
|
||||
)
|
||||
from vllm.entrypoints.openai.serving_engine import (
|
||||
ClassificationServeContext,
|
||||
OpenAIServing,
|
||||
ServeContext,
|
||||
)
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.entrypoints.renderer import RenderConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import ClassificationOutput, PoolingRequestOutput
|
||||
from vllm.pooling_params import PoolingParams
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class ClassificationMixin(OpenAIServing):
|
||||
chat_template: str | None
|
||||
chat_template_content_format: ChatTemplateContentFormatOption
|
||||
trust_request_chat_template: bool
|
||||
|
||||
async def _preprocess(
|
||||
self,
|
||||
ctx: ServeContext,
|
||||
) -> ErrorResponse | None:
|
||||
"""
|
||||
Process classification inputs: tokenize text, resolve adapters,
|
||||
and prepare model-specific inputs.
|
||||
"""
|
||||
ctx = cast(ClassificationServeContext, ctx)
|
||||
try:
|
||||
ctx.tokenizer = await self.engine_client.get_tokenizer()
|
||||
|
||||
request_obj = ctx.request
|
||||
|
||||
if isinstance(request_obj, ClassificationChatRequest):
|
||||
chat_request = request_obj
|
||||
messages = chat_request.messages
|
||||
trust_request_chat_template = getattr(
|
||||
self,
|
||||
"trust_request_chat_template",
|
||||
False,
|
||||
)
|
||||
ret = self._validate_chat_template(
|
||||
request_chat_template=chat_request.chat_template,
|
||||
chat_template_kwargs=chat_request.chat_template_kwargs,
|
||||
trust_request_chat_template=trust_request_chat_template,
|
||||
)
|
||||
if ret:
|
||||
return ret
|
||||
|
||||
(
|
||||
_,
|
||||
_,
|
||||
engine_prompts,
|
||||
) = await self._preprocess_chat(
|
||||
cast(ChatCompletionRequest, chat_request),
|
||||
ctx.tokenizer,
|
||||
messages,
|
||||
chat_template=(
|
||||
chat_request.chat_template
|
||||
or getattr(self, "chat_template", None)
|
||||
),
|
||||
chat_template_content_format=cast(
|
||||
ChatTemplateContentFormatOption,
|
||||
getattr(self, "chat_template_content_format", "auto"),
|
||||
),
|
||||
add_generation_prompt=False,
|
||||
continue_final_message=False,
|
||||
add_special_tokens=chat_request.add_special_tokens,
|
||||
)
|
||||
ctx.engine_prompts = engine_prompts
|
||||
|
||||
elif isinstance(request_obj, ClassificationCompletionRequest):
|
||||
completion_request = request_obj
|
||||
input_data = completion_request.input
|
||||
if input_data in (None, ""):
|
||||
return self.create_error_response(
|
||||
"Input or messages must be provided",
|
||||
status_code=HTTPStatus.BAD_REQUEST,
|
||||
)
|
||||
if isinstance(input_data, list) and not input_data:
|
||||
ctx.engine_prompts = []
|
||||
return None
|
||||
|
||||
renderer = self._get_renderer(ctx.tokenizer)
|
||||
prompt_input = cast(str | list[str], input_data)
|
||||
ctx.engine_prompts = await renderer.render_prompt(
|
||||
prompt_or_prompts=prompt_input,
|
||||
config=self._build_render_config(completion_request),
|
||||
)
|
||||
else:
|
||||
return self.create_error_response(
|
||||
"Invalid classification request type",
|
||||
status_code=HTTPStatus.BAD_REQUEST,
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
except (ValueError, TypeError, jinja2.TemplateError) as e:
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
def _build_response(
|
||||
self,
|
||||
ctx: ServeContext,
|
||||
) -> ClassificationResponse | ErrorResponse:
|
||||
"""
|
||||
Convert model outputs to a formatted classification response
|
||||
with probabilities and labels.
|
||||
"""
|
||||
ctx = cast(ClassificationServeContext, ctx)
|
||||
items: list[ClassificationData] = []
|
||||
num_prompt_tokens = 0
|
||||
|
||||
final_res_batch_checked = cast(list[PoolingRequestOutput], ctx.final_res_batch)
|
||||
|
||||
for idx, final_res in enumerate(final_res_batch_checked):
|
||||
classify_res = ClassificationOutput.from_base(final_res.outputs)
|
||||
|
||||
probs = classify_res.probs
|
||||
predicted_index = int(np.argmax(probs))
|
||||
label = getattr(self.model_config.hf_config, "id2label", {}).get(
|
||||
predicted_index
|
||||
)
|
||||
|
||||
item = ClassificationData(
|
||||
index=idx,
|
||||
label=label,
|
||||
probs=probs,
|
||||
num_classes=len(probs),
|
||||
)
|
||||
|
||||
items.append(item)
|
||||
prompt_token_ids = final_res.prompt_token_ids
|
||||
num_prompt_tokens += len(prompt_token_ids)
|
||||
|
||||
usage = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
total_tokens=num_prompt_tokens,
|
||||
)
|
||||
|
||||
return ClassificationResponse(
|
||||
id=ctx.request_id,
|
||||
created=ctx.created_time,
|
||||
model=ctx.model_name,
|
||||
data=items,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
def _build_render_config(self, request: ClassificationRequest) -> RenderConfig:
|
||||
return RenderConfig(
|
||||
max_length=self.max_model_len,
|
||||
truncate_prompt_tokens=request.truncate_prompt_tokens,
|
||||
add_special_tokens=request.add_special_tokens,
|
||||
)
|
||||
|
||||
|
||||
class ServingClassification(ClassificationMixin):
|
||||
request_id_prefix = "classify"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: RequestLogger | None,
|
||||
chat_template: str | None = None,
|
||||
chat_template_content_format: ChatTemplateContentFormatOption = "auto",
|
||||
trust_request_chat_template: bool = False,
|
||||
log_error_stack: bool = False,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
engine_client=engine_client,
|
||||
models=models,
|
||||
request_logger=request_logger,
|
||||
log_error_stack=log_error_stack,
|
||||
)
|
||||
|
||||
self.chat_template = chat_template
|
||||
self.chat_template_content_format = chat_template_content_format
|
||||
self.trust_request_chat_template = trust_request_chat_template
|
||||
|
||||
async def create_classify(
|
||||
self,
|
||||
request: ClassificationRequest,
|
||||
raw_request: Request,
|
||||
) -> ClassificationResponse | ErrorResponse:
|
||||
model_name = self.models.model_name()
|
||||
request_id = f"{self.request_id_prefix}-{self._base_request_id(raw_request)}"
|
||||
|
||||
ctx = ClassificationServeContext(
|
||||
request=request,
|
||||
raw_request=raw_request,
|
||||
model_name=model_name,
|
||||
request_id=request_id,
|
||||
)
|
||||
|
||||
return await super().handle(ctx) # type: ignore
|
||||
|
||||
def _create_pooling_params(
|
||||
self,
|
||||
ctx: ClassificationServeContext,
|
||||
) -> PoolingParams | ErrorResponse:
|
||||
pooling_params = super()._create_pooling_params(ctx)
|
||||
if isinstance(pooling_params, ErrorResponse):
|
||||
return pooling_params
|
||||
|
||||
try:
|
||||
pooling_params.verify("classify", self.model_config)
|
||||
except ValueError as e:
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
return pooling_params
|
||||
715
vllm_old/entrypoints/openai/serving_completion.py
Normal file
715
vllm_old/entrypoints/openai/serving_completion.py
Normal file
@@ -0,0 +1,715 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
from collections.abc import Sequence as GenericSequence
|
||||
from typing import cast
|
||||
|
||||
import jinja2
|
||||
from fastapi import Request
|
||||
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
CompletionLogProbs,
|
||||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
CompletionResponseChoice,
|
||||
CompletionResponseStreamChoice,
|
||||
CompletionStreamResponse,
|
||||
ErrorResponse,
|
||||
PromptTokenUsageInfo,
|
||||
RequestResponseMetadata,
|
||||
UsageInfo,
|
||||
)
|
||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing, clamp_prompt_logprobs
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.entrypoints.renderer import RenderConfig
|
||||
from vllm.entrypoints.utils import get_max_tokens, should_include_usage
|
||||
from vllm.inputs.data import EmbedsPrompt, TokensPrompt, is_embeds_prompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logprobs import Logprob
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.utils.async_utils import merge_async_iterators
|
||||
from vllm.utils.collection_utils import as_list
|
||||
from vllm.v1.sample.logits_processor import validate_logits_processors_parameters
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class OpenAIServingCompletion(OpenAIServing):
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: RequestLogger | None,
|
||||
return_tokens_as_token_ids: bool = False,
|
||||
enable_prompt_tokens_details: bool = False,
|
||||
enable_force_include_usage: bool = False,
|
||||
log_error_stack: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
engine_client=engine_client,
|
||||
models=models,
|
||||
request_logger=request_logger,
|
||||
return_tokens_as_token_ids=return_tokens_as_token_ids,
|
||||
log_error_stack=log_error_stack,
|
||||
)
|
||||
|
||||
# set up logits processors
|
||||
self.logits_processors = self.model_config.logits_processors
|
||||
|
||||
self.enable_prompt_tokens_details = enable_prompt_tokens_details
|
||||
self.default_sampling_params = self.model_config.get_diff_sampling_param()
|
||||
self.enable_force_include_usage = enable_force_include_usage
|
||||
if self.default_sampling_params:
|
||||
source = self.model_config.generation_config
|
||||
source = "model" if source == "auto" else source
|
||||
logger.info(
|
||||
"Using default completion sampling params from %s: %s",
|
||||
source,
|
||||
self.default_sampling_params,
|
||||
)
|
||||
|
||||
async def create_completion(
|
||||
self,
|
||||
request: CompletionRequest,
|
||||
raw_request: Request | None = None,
|
||||
) -> AsyncGenerator[str, None] | CompletionResponse | ErrorResponse:
|
||||
"""Completion API similar to OpenAI's API.
|
||||
|
||||
See https://platform.openai.com/docs/api-reference/completions/create
|
||||
for the API specification. This API mimics the OpenAI Completion API.
|
||||
|
||||
NOTE: Currently we do not support the following feature:
|
||||
- suffix (the language models we currently support do not support
|
||||
suffix)
|
||||
"""
|
||||
error_check_ret = await self._check_model(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
# If the engine is dead, raise the engine's DEAD_ERROR.
|
||||
# This is required for the streaming case, where we return a
|
||||
# success status before we actually start generating text :).
|
||||
if self.engine_client.errored:
|
||||
raise self.engine_client.dead_error
|
||||
|
||||
# Return error for unsupported features.
|
||||
if request.suffix is not None:
|
||||
return self.create_error_response("suffix is not currently supported")
|
||||
|
||||
if request.echo and request.prompt_embeds is not None:
|
||||
return self.create_error_response("Echo is unsupported with prompt embeds.")
|
||||
|
||||
if request.prompt_logprobs is not None and request.prompt_embeds is not None:
|
||||
return self.create_error_response(
|
||||
"prompt_logprobs is not compatible with prompt embeds."
|
||||
)
|
||||
|
||||
request_id = f"cmpl-{self._base_request_id(raw_request, request.request_id)}"
|
||||
created_time = int(time.time())
|
||||
|
||||
request_metadata = RequestResponseMetadata(request_id=request_id)
|
||||
if raw_request:
|
||||
raw_request.state.request_metadata = request_metadata
|
||||
|
||||
try:
|
||||
lora_request = self._maybe_get_adapters(request)
|
||||
|
||||
if self.model_config.skip_tokenizer_init:
|
||||
tokenizer = None
|
||||
else:
|
||||
tokenizer = await self.engine_client.get_tokenizer()
|
||||
renderer = self._get_renderer(tokenizer)
|
||||
|
||||
engine_prompts = await renderer.render_prompt_and_embeds(
|
||||
prompt_or_prompts=request.prompt,
|
||||
prompt_embeds=request.prompt_embeds,
|
||||
config=self._build_render_config(request),
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
return self.create_error_response(str(e))
|
||||
except TypeError as e:
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
return self.create_error_response(str(e))
|
||||
except RuntimeError as e:
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
return self.create_error_response(str(e))
|
||||
except jinja2.TemplateError as e:
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
# Extract data_parallel_rank from header (router can inject it)
|
||||
data_parallel_rank = self._get_data_parallel_rank(raw_request)
|
||||
|
||||
# Schedule the request and get the result generator.
|
||||
generators: list[AsyncGenerator[RequestOutput, None]] = []
|
||||
try:
|
||||
for i, engine_prompt in enumerate(engine_prompts):
|
||||
prompt_text, prompt_token_ids, prompt_embeds = (
|
||||
self._get_prompt_components(engine_prompt)
|
||||
)
|
||||
|
||||
input_length = None
|
||||
if prompt_token_ids is not None:
|
||||
input_length = len(prompt_token_ids)
|
||||
elif prompt_embeds is not None:
|
||||
input_length = len(prompt_embeds)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
if self.default_sampling_params is None:
|
||||
self.default_sampling_params = {}
|
||||
|
||||
max_tokens = get_max_tokens(
|
||||
max_model_len=self.max_model_len,
|
||||
request=request,
|
||||
input_length=input_length,
|
||||
default_sampling_params=self.default_sampling_params,
|
||||
)
|
||||
|
||||
sampling_params: SamplingParams | BeamSearchParams
|
||||
if request.use_beam_search:
|
||||
sampling_params = request.to_beam_search_params(
|
||||
max_tokens, self.default_sampling_params
|
||||
)
|
||||
else:
|
||||
sampling_params = request.to_sampling_params(
|
||||
max_tokens,
|
||||
self.model_config.logits_processor_pattern,
|
||||
self.default_sampling_params,
|
||||
)
|
||||
validate_logits_processors_parameters(
|
||||
self.logits_processors,
|
||||
sampling_params,
|
||||
)
|
||||
|
||||
request_id_item = f"{request_id}-{i}"
|
||||
|
||||
self._log_inputs(
|
||||
request_id_item,
|
||||
engine_prompt,
|
||||
params=sampling_params,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
|
||||
trace_headers = (
|
||||
None
|
||||
if raw_request is None
|
||||
else await self._get_trace_headers(raw_request.headers)
|
||||
)
|
||||
|
||||
# Mypy inconsistently requires this second cast in different
|
||||
# environments. It shouldn't be necessary (redundant from above)
|
||||
# but pre-commit in CI fails without it.
|
||||
engine_prompt = cast(EmbedsPrompt | TokensPrompt, engine_prompt)
|
||||
if isinstance(sampling_params, BeamSearchParams):
|
||||
generator = self.beam_search(
|
||||
prompt=engine_prompt,
|
||||
request_id=request_id,
|
||||
params=sampling_params,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
else:
|
||||
engine_request, tokenization_kwargs = await self._process_inputs(
|
||||
request_id_item,
|
||||
engine_prompt,
|
||||
sampling_params,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
priority=request.priority,
|
||||
)
|
||||
|
||||
generator = self.engine_client.generate(
|
||||
engine_request,
|
||||
sampling_params,
|
||||
request_id_item,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
priority=request.priority,
|
||||
prompt_text=prompt_text,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
data_parallel_rank=data_parallel_rank,
|
||||
)
|
||||
|
||||
generators.append(generator)
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
result_generator = merge_async_iterators(*generators)
|
||||
|
||||
model_name = self.models.model_name(lora_request)
|
||||
num_prompts = len(engine_prompts)
|
||||
|
||||
# Similar to the OpenAI API, when n != best_of, we do not stream the
|
||||
# results. Noting that best_of is only supported in V0. In addition,
|
||||
# we do not stream the results when use beam search.
|
||||
stream = (
|
||||
request.stream
|
||||
and (request.best_of is None or request.n == request.best_of)
|
||||
and not request.use_beam_search
|
||||
)
|
||||
|
||||
# Streaming response
|
||||
if stream:
|
||||
return self.completion_stream_generator(
|
||||
request,
|
||||
engine_prompts,
|
||||
result_generator,
|
||||
request_id,
|
||||
created_time,
|
||||
model_name,
|
||||
num_prompts=num_prompts,
|
||||
tokenizer=tokenizer,
|
||||
request_metadata=request_metadata,
|
||||
)
|
||||
|
||||
# Non-streaming response
|
||||
final_res_batch: list[RequestOutput | None] = [None] * num_prompts
|
||||
try:
|
||||
async for i, res in result_generator:
|
||||
final_res_batch[i] = res
|
||||
|
||||
for i, final_res in enumerate(final_res_batch):
|
||||
assert final_res is not None
|
||||
|
||||
# The output should contain the input text
|
||||
# We did not pass it into vLLM engine to avoid being redundant
|
||||
# with the inputs token IDs
|
||||
if final_res.prompt is None:
|
||||
engine_prompt = engine_prompts[i]
|
||||
final_res.prompt = (
|
||||
None
|
||||
if is_embeds_prompt(engine_prompt)
|
||||
else engine_prompt.get("prompt")
|
||||
)
|
||||
|
||||
final_res_batch_checked = cast(list[RequestOutput], final_res_batch)
|
||||
|
||||
response = self.request_output_to_completion_response(
|
||||
final_res_batch_checked,
|
||||
request,
|
||||
request_id,
|
||||
created_time,
|
||||
model_name,
|
||||
tokenizer,
|
||||
request_metadata,
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
return self.create_error_response("Client disconnected")
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
# When user requests streaming but we don't stream, we still need to
|
||||
# return a streaming response with a single event.
|
||||
if request.stream:
|
||||
response_json = response.model_dump_json()
|
||||
|
||||
async def fake_stream_generator() -> AsyncGenerator[str, None]:
|
||||
yield f"data: {response_json}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return fake_stream_generator()
|
||||
|
||||
return response
|
||||
|
||||
async def completion_stream_generator(
|
||||
self,
|
||||
request: CompletionRequest,
|
||||
engine_prompts: list[TokensPrompt | EmbedsPrompt],
|
||||
result_generator: AsyncIterator[tuple[int, RequestOutput]],
|
||||
request_id: str,
|
||||
created_time: int,
|
||||
model_name: str,
|
||||
num_prompts: int,
|
||||
tokenizer: AnyTokenizer,
|
||||
request_metadata: RequestResponseMetadata,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
num_choices = 1 if request.n is None else request.n
|
||||
previous_text_lens = [0] * num_choices * num_prompts
|
||||
previous_num_tokens = [0] * num_choices * num_prompts
|
||||
has_echoed = [False] * num_choices * num_prompts
|
||||
num_prompt_tokens = [0] * num_prompts
|
||||
num_cached_tokens = None
|
||||
first_iteration = True
|
||||
|
||||
stream_options = request.stream_options
|
||||
include_usage, include_continuous_usage = should_include_usage(
|
||||
stream_options, self.enable_force_include_usage
|
||||
)
|
||||
|
||||
try:
|
||||
async for prompt_idx, res in result_generator:
|
||||
prompt_token_ids = res.prompt_token_ids
|
||||
prompt_logprobs = res.prompt_logprobs
|
||||
|
||||
if first_iteration:
|
||||
num_cached_tokens = res.num_cached_tokens
|
||||
first_iteration = False
|
||||
|
||||
prompt_text = res.prompt
|
||||
if prompt_text is None:
|
||||
engine_prompt = engine_prompts[prompt_idx]
|
||||
prompt_text = (
|
||||
None
|
||||
if is_embeds_prompt(engine_prompt)
|
||||
else engine_prompt.get("prompt")
|
||||
)
|
||||
|
||||
# Prompt details are excluded from later streamed outputs
|
||||
if prompt_token_ids is not None:
|
||||
num_prompt_tokens[prompt_idx] = len(prompt_token_ids)
|
||||
|
||||
delta_token_ids: GenericSequence[int]
|
||||
out_logprobs: GenericSequence[dict[int, Logprob] | None] | None
|
||||
|
||||
for output in res.outputs:
|
||||
i = output.index + prompt_idx * num_choices
|
||||
|
||||
# Useful when request.return_token_ids is True
|
||||
# Returning prompt token IDs shares the same logic
|
||||
# with the echo implementation.
|
||||
prompt_token_ids_to_return: list[int] | None = None
|
||||
|
||||
assert request.max_tokens is not None
|
||||
if request.echo and not has_echoed[i]:
|
||||
assert prompt_token_ids is not None
|
||||
if request.return_token_ids:
|
||||
prompt_text = ""
|
||||
assert prompt_text is not None
|
||||
if request.max_tokens == 0:
|
||||
# only return the prompt
|
||||
delta_text = prompt_text
|
||||
delta_token_ids = prompt_token_ids
|
||||
out_logprobs = prompt_logprobs
|
||||
else:
|
||||
# echo the prompt and first token
|
||||
delta_text = prompt_text + output.text
|
||||
delta_token_ids = [
|
||||
*prompt_token_ids,
|
||||
*output.token_ids,
|
||||
]
|
||||
out_logprobs = [
|
||||
*(prompt_logprobs or []),
|
||||
*(output.logprobs or []),
|
||||
]
|
||||
prompt_token_ids_to_return = prompt_token_ids
|
||||
has_echoed[i] = True
|
||||
else:
|
||||
# return just the delta
|
||||
delta_text = output.text
|
||||
delta_token_ids = output.token_ids
|
||||
out_logprobs = output.logprobs
|
||||
|
||||
# has_echoed[i] is reused here to indicate whether
|
||||
# we have already returned the prompt token IDs.
|
||||
if not has_echoed[i] and request.return_token_ids:
|
||||
prompt_token_ids_to_return = prompt_token_ids
|
||||
has_echoed[i] = True
|
||||
|
||||
if (
|
||||
not delta_text
|
||||
and not delta_token_ids
|
||||
and not previous_num_tokens[i]
|
||||
):
|
||||
# Chunked prefill case, don't return empty chunks
|
||||
continue
|
||||
|
||||
if request.logprobs is not None:
|
||||
assert out_logprobs is not None, "Did not output logprobs"
|
||||
logprobs = self._create_completion_logprobs(
|
||||
token_ids=delta_token_ids,
|
||||
top_logprobs=out_logprobs,
|
||||
num_output_top_logprobs=request.logprobs,
|
||||
tokenizer=tokenizer,
|
||||
initial_text_offset=previous_text_lens[i],
|
||||
return_as_token_id=request.return_tokens_as_token_ids,
|
||||
)
|
||||
else:
|
||||
logprobs = None
|
||||
|
||||
previous_text_lens[i] += len(output.text)
|
||||
previous_num_tokens[i] += len(output.token_ids)
|
||||
finish_reason = output.finish_reason
|
||||
stop_reason = output.stop_reason
|
||||
|
||||
chunk = CompletionStreamResponse(
|
||||
id=request_id,
|
||||
created=created_time,
|
||||
model=model_name,
|
||||
choices=[
|
||||
CompletionResponseStreamChoice(
|
||||
index=i,
|
||||
text=delta_text,
|
||||
logprobs=logprobs,
|
||||
finish_reason=finish_reason,
|
||||
stop_reason=stop_reason,
|
||||
prompt_token_ids=prompt_token_ids_to_return,
|
||||
token_ids=(
|
||||
as_list(output.token_ids)
|
||||
if request.return_token_ids
|
||||
else None
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
if include_continuous_usage:
|
||||
prompt_tokens = num_prompt_tokens[prompt_idx]
|
||||
completion_tokens = previous_num_tokens[i]
|
||||
chunk.usage = UsageInfo(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
|
||||
response_json = chunk.model_dump_json(exclude_unset=False)
|
||||
yield f"data: {response_json}\n\n"
|
||||
|
||||
total_prompt_tokens = sum(num_prompt_tokens)
|
||||
total_completion_tokens = sum(previous_num_tokens)
|
||||
final_usage_info = UsageInfo(
|
||||
prompt_tokens=total_prompt_tokens,
|
||||
completion_tokens=total_completion_tokens,
|
||||
total_tokens=total_prompt_tokens + total_completion_tokens,
|
||||
)
|
||||
|
||||
if self.enable_prompt_tokens_details and num_cached_tokens:
|
||||
final_usage_info.prompt_tokens_details = PromptTokenUsageInfo(
|
||||
cached_tokens=num_cached_tokens
|
||||
)
|
||||
|
||||
if include_usage:
|
||||
final_usage_chunk = CompletionStreamResponse(
|
||||
id=request_id,
|
||||
created=created_time,
|
||||
model=model_name,
|
||||
choices=[],
|
||||
usage=final_usage_info,
|
||||
)
|
||||
final_usage_data = final_usage_chunk.model_dump_json(
|
||||
exclude_unset=False, exclude_none=True
|
||||
)
|
||||
yield f"data: {final_usage_data}\n\n"
|
||||
|
||||
# report to FastAPI middleware aggregate usage across all choices
|
||||
request_metadata.final_usage_info = final_usage_info
|
||||
|
||||
except Exception as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
data = self.create_streaming_error_response(str(e))
|
||||
yield f"data: {data}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
def request_output_to_completion_response(
|
||||
self,
|
||||
final_res_batch: list[RequestOutput],
|
||||
request: CompletionRequest,
|
||||
request_id: str,
|
||||
created_time: int,
|
||||
model_name: str,
|
||||
tokenizer: AnyTokenizer,
|
||||
request_metadata: RequestResponseMetadata,
|
||||
) -> CompletionResponse:
|
||||
choices: list[CompletionResponseChoice] = []
|
||||
num_prompt_tokens = 0
|
||||
num_generated_tokens = 0
|
||||
kv_transfer_params = None
|
||||
last_final_res = None
|
||||
for final_res in final_res_batch:
|
||||
last_final_res = final_res
|
||||
prompt_token_ids = final_res.prompt_token_ids
|
||||
assert prompt_token_ids is not None
|
||||
prompt_logprobs = clamp_prompt_logprobs(final_res.prompt_logprobs)
|
||||
prompt_text = final_res.prompt
|
||||
|
||||
token_ids: GenericSequence[int]
|
||||
out_logprobs: GenericSequence[dict[int, Logprob] | None] | None
|
||||
|
||||
for output in final_res.outputs:
|
||||
assert request.max_tokens is not None
|
||||
if request.echo:
|
||||
if request.return_token_ids:
|
||||
prompt_text = ""
|
||||
assert prompt_text is not None
|
||||
if request.max_tokens == 0:
|
||||
token_ids = prompt_token_ids
|
||||
out_logprobs = prompt_logprobs
|
||||
output_text = prompt_text
|
||||
else:
|
||||
token_ids = [*prompt_token_ids, *output.token_ids]
|
||||
|
||||
if request.logprobs is None:
|
||||
out_logprobs = None
|
||||
else:
|
||||
assert prompt_logprobs is not None
|
||||
assert output.logprobs is not None
|
||||
out_logprobs = [
|
||||
*prompt_logprobs,
|
||||
*output.logprobs,
|
||||
]
|
||||
|
||||
output_text = prompt_text + output.text
|
||||
else:
|
||||
token_ids = output.token_ids
|
||||
out_logprobs = output.logprobs
|
||||
output_text = output.text
|
||||
|
||||
if request.logprobs is not None:
|
||||
assert out_logprobs is not None, "Did not output logprobs"
|
||||
logprobs = self._create_completion_logprobs(
|
||||
token_ids=token_ids,
|
||||
top_logprobs=out_logprobs,
|
||||
tokenizer=tokenizer,
|
||||
num_output_top_logprobs=request.logprobs,
|
||||
return_as_token_id=request.return_tokens_as_token_ids,
|
||||
)
|
||||
else:
|
||||
logprobs = None
|
||||
|
||||
choice_data = CompletionResponseChoice(
|
||||
index=len(choices),
|
||||
text=output_text,
|
||||
logprobs=logprobs,
|
||||
finish_reason=output.finish_reason,
|
||||
stop_reason=output.stop_reason,
|
||||
prompt_logprobs=final_res.prompt_logprobs,
|
||||
prompt_token_ids=(
|
||||
prompt_token_ids if request.return_token_ids else None
|
||||
),
|
||||
token_ids=(
|
||||
as_list(output.token_ids) if request.return_token_ids else None
|
||||
),
|
||||
)
|
||||
choices.append(choice_data)
|
||||
|
||||
num_generated_tokens += len(output.token_ids)
|
||||
|
||||
num_prompt_tokens += len(prompt_token_ids)
|
||||
|
||||
usage = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
completion_tokens=num_generated_tokens,
|
||||
total_tokens=num_prompt_tokens + num_generated_tokens,
|
||||
)
|
||||
|
||||
if (
|
||||
self.enable_prompt_tokens_details
|
||||
and last_final_res
|
||||
and last_final_res.num_cached_tokens
|
||||
):
|
||||
usage.prompt_tokens_details = PromptTokenUsageInfo(
|
||||
cached_tokens=last_final_res.num_cached_tokens
|
||||
)
|
||||
|
||||
request_metadata.final_usage_info = usage
|
||||
if final_res_batch:
|
||||
kv_transfer_params = final_res_batch[0].kv_transfer_params
|
||||
return CompletionResponse(
|
||||
id=request_id,
|
||||
created=created_time,
|
||||
model=model_name,
|
||||
choices=choices,
|
||||
usage=usage,
|
||||
kv_transfer_params=kv_transfer_params,
|
||||
)
|
||||
|
||||
def _create_completion_logprobs(
|
||||
self,
|
||||
token_ids: GenericSequence[int],
|
||||
top_logprobs: GenericSequence[dict[int, Logprob] | None],
|
||||
num_output_top_logprobs: int,
|
||||
tokenizer: AnyTokenizer,
|
||||
initial_text_offset: int = 0,
|
||||
return_as_token_id: bool | None = None,
|
||||
) -> CompletionLogProbs:
|
||||
"""Create logprobs for OpenAI Completion API."""
|
||||
out_text_offset: list[int] = []
|
||||
out_token_logprobs: list[float | None] = []
|
||||
out_tokens: list[str] = []
|
||||
out_top_logprobs: list[dict[str, float] | None] = []
|
||||
|
||||
last_token_len = 0
|
||||
|
||||
should_return_as_token_id = (
|
||||
return_as_token_id
|
||||
if return_as_token_id is not None
|
||||
else self.return_tokens_as_token_ids
|
||||
)
|
||||
for i, token_id in enumerate(token_ids):
|
||||
step_top_logprobs = top_logprobs[i]
|
||||
if step_top_logprobs is None:
|
||||
token = tokenizer.decode(token_id)
|
||||
if should_return_as_token_id:
|
||||
token = f"token_id:{token_id}"
|
||||
|
||||
out_tokens.append(token)
|
||||
out_token_logprobs.append(None)
|
||||
out_top_logprobs.append(None)
|
||||
else:
|
||||
step_token = step_top_logprobs[token_id]
|
||||
|
||||
token = self._get_decoded_token(
|
||||
step_token,
|
||||
token_id,
|
||||
tokenizer,
|
||||
return_as_token_id=should_return_as_token_id,
|
||||
)
|
||||
token_logprob = max(step_token.logprob, -9999.0)
|
||||
|
||||
out_tokens.append(token)
|
||||
out_token_logprobs.append(token_logprob)
|
||||
|
||||
# makes sure to add the top num_output_top_logprobs + 1
|
||||
# logprobs, as defined in the openai API
|
||||
# (cf. https://github.com/openai/openai-openapi/blob/
|
||||
# 893ba52242dbd5387a97b96444ee1c742cfce9bd/openapi.yaml#L7153)
|
||||
out_top_logprobs.append(
|
||||
{
|
||||
# Convert float("-inf") to the
|
||||
# JSON-serializable float that OpenAI uses
|
||||
self._get_decoded_token(
|
||||
top_lp[1],
|
||||
top_lp[0],
|
||||
tokenizer,
|
||||
return_as_token_id=should_return_as_token_id,
|
||||
): max(top_lp[1].logprob, -9999.0)
|
||||
for i, top_lp in enumerate(step_top_logprobs.items())
|
||||
if num_output_top_logprobs >= i
|
||||
}
|
||||
)
|
||||
|
||||
if len(out_text_offset) == 0:
|
||||
out_text_offset.append(initial_text_offset)
|
||||
else:
|
||||
out_text_offset.append(out_text_offset[-1] + last_token_len)
|
||||
last_token_len = len(token)
|
||||
|
||||
return CompletionLogProbs(
|
||||
text_offset=out_text_offset,
|
||||
token_logprobs=out_token_logprobs,
|
||||
tokens=out_tokens,
|
||||
top_logprobs=out_top_logprobs,
|
||||
)
|
||||
|
||||
def _build_render_config(
|
||||
self,
|
||||
request: CompletionRequest,
|
||||
max_input_length: int | None = None,
|
||||
) -> RenderConfig:
|
||||
max_input_tokens_len = self.max_model_len - (request.max_tokens or 0)
|
||||
return RenderConfig(
|
||||
max_length=max_input_tokens_len,
|
||||
truncate_prompt_tokens=request.truncate_prompt_tokens,
|
||||
add_special_tokens=request.add_special_tokens,
|
||||
cache_salt=request.cache_salt,
|
||||
needs_detokenization=bool(request.echo and not request.return_token_ids),
|
||||
)
|
||||
695
vllm_old/entrypoints/openai/serving_embedding.py
Normal file
695
vllm_old/entrypoints/openai/serving_embedding.py
Normal file
@@ -0,0 +1,695 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import json
|
||||
from collections.abc import AsyncGenerator, Mapping
|
||||
from typing import Any, Final, cast
|
||||
|
||||
import torch
|
||||
from fastapi import Request
|
||||
from fastapi.responses import Response
|
||||
from typing_extensions import assert_never, override
|
||||
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
EmbeddingBytesResponse,
|
||||
EmbeddingChatRequest,
|
||||
EmbeddingCompletionRequest,
|
||||
EmbeddingRequest,
|
||||
EmbeddingResponse,
|
||||
EmbeddingResponseData,
|
||||
ErrorResponse,
|
||||
UsageInfo,
|
||||
)
|
||||
from vllm.entrypoints.openai.serving_engine import (
|
||||
EmbeddingServeContext,
|
||||
OpenAIServing,
|
||||
ServeContext,
|
||||
TextTokensPrompt,
|
||||
)
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.entrypoints.renderer import RenderConfig
|
||||
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import (
|
||||
EmbeddingRequestOutput,
|
||||
PoolingOutput,
|
||||
PoolingRequestOutput,
|
||||
RequestOutput,
|
||||
)
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.utils.async_utils import merge_async_iterators
|
||||
from vllm.utils.collection_utils import chunk_list
|
||||
from vllm.utils.serial_utils import (
|
||||
EmbedDType,
|
||||
EncodingFormat,
|
||||
Endianness,
|
||||
encode_pooling_bytes,
|
||||
encode_pooling_output,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class EmbeddingMixin(OpenAIServing):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
pooler_config = self.model_config.pooler_config
|
||||
|
||||
# Avoid repeated attribute lookups
|
||||
self.supports_chunked_processing = bool(
|
||||
pooler_config and pooler_config.enable_chunked_processing
|
||||
)
|
||||
self.max_embed_len = (
|
||||
pooler_config.max_embed_len
|
||||
if pooler_config and pooler_config.max_embed_len
|
||||
else None
|
||||
)
|
||||
|
||||
@override
|
||||
async def _preprocess(
|
||||
self,
|
||||
ctx: ServeContext,
|
||||
) -> ErrorResponse | None:
|
||||
ctx = cast(EmbeddingServeContext, ctx)
|
||||
try:
|
||||
ctx.lora_request = self._maybe_get_adapters(ctx.request)
|
||||
|
||||
tokenizer = await self.engine_client.get_tokenizer()
|
||||
renderer = self._get_renderer(tokenizer)
|
||||
|
||||
if isinstance(ctx.request, EmbeddingChatRequest):
|
||||
(
|
||||
_,
|
||||
_,
|
||||
ctx.engine_prompts,
|
||||
) = await self._preprocess_chat(
|
||||
ctx.request,
|
||||
tokenizer,
|
||||
ctx.request.messages,
|
||||
chat_template=ctx.request.chat_template or ctx.chat_template,
|
||||
chat_template_content_format=ctx.chat_template_content_format,
|
||||
add_generation_prompt=ctx.request.add_generation_prompt,
|
||||
continue_final_message=False,
|
||||
add_special_tokens=ctx.request.add_special_tokens,
|
||||
)
|
||||
else:
|
||||
ctx.engine_prompts = await renderer.render_prompt(
|
||||
prompt_or_prompts=ctx.request.input,
|
||||
config=self._build_render_config(ctx.request),
|
||||
)
|
||||
return None
|
||||
except (ValueError, TypeError) as e:
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
def _build_render_config(self, request: EmbeddingCompletionRequest) -> RenderConfig:
|
||||
# Set max_length based on chunked processing capability
|
||||
if self._should_use_chunked_processing(request):
|
||||
max_length = None
|
||||
else:
|
||||
max_length = self.max_embed_len or self.max_model_len
|
||||
|
||||
return RenderConfig(
|
||||
max_length=max_length,
|
||||
truncate_prompt_tokens=request.truncate_prompt_tokens,
|
||||
add_special_tokens=request.add_special_tokens,
|
||||
)
|
||||
|
||||
@override
|
||||
def _build_response(
|
||||
self,
|
||||
ctx: ServeContext,
|
||||
) -> EmbeddingResponse | Response | ErrorResponse:
|
||||
final_res_batch_checked = cast(list[PoolingRequestOutput], ctx.final_res_batch)
|
||||
|
||||
encoding_format: EncodingFormat = ctx.request.encoding_format
|
||||
embed_dtype: EmbedDType = ctx.request.embed_dtype
|
||||
endianness: Endianness = ctx.request.endianness
|
||||
|
||||
def encode_float_base64():
|
||||
items: list[EmbeddingResponseData] = []
|
||||
num_prompt_tokens = 0
|
||||
|
||||
for idx, final_res in enumerate(final_res_batch_checked):
|
||||
item = EmbeddingResponseData(
|
||||
index=idx,
|
||||
embedding=encode_pooling_output(
|
||||
final_res,
|
||||
encoding_format=encoding_format,
|
||||
embed_dtype=embed_dtype,
|
||||
endianness=endianness,
|
||||
),
|
||||
)
|
||||
prompt_token_ids = final_res.prompt_token_ids
|
||||
|
||||
items.append(item)
|
||||
num_prompt_tokens += len(prompt_token_ids)
|
||||
|
||||
usage = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
total_tokens=num_prompt_tokens,
|
||||
)
|
||||
|
||||
return EmbeddingResponse(
|
||||
id=ctx.request_id,
|
||||
created=ctx.created_time,
|
||||
model=ctx.model_name,
|
||||
data=items,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
def encode_bytes():
|
||||
body, items, usage = encode_pooling_bytes(
|
||||
pooling_outputs=final_res_batch_checked,
|
||||
embed_dtype=embed_dtype,
|
||||
endianness=endianness,
|
||||
)
|
||||
|
||||
metadata = {
|
||||
"id": ctx.request_id,
|
||||
"created": ctx.created_time,
|
||||
"model": ctx.model_name,
|
||||
"data": items,
|
||||
"usage": usage,
|
||||
}
|
||||
return EmbeddingBytesResponse(
|
||||
body=body,
|
||||
metadata=json.dumps(metadata),
|
||||
)
|
||||
|
||||
if encoding_format == "float" or encoding_format == "base64":
|
||||
return encode_float_base64()
|
||||
elif encoding_format == "bytes":
|
||||
return encode_bytes()
|
||||
else:
|
||||
assert_never(encoding_format)
|
||||
|
||||
def _get_max_position_embeddings(self) -> int:
|
||||
"""Get the model's effective maximum sequence length for chunking."""
|
||||
return self.model_config.max_model_len
|
||||
|
||||
def _should_use_chunked_processing(self, request) -> bool:
|
||||
"""Check if chunked processing should be used for this request."""
|
||||
return (
|
||||
isinstance(request, (EmbeddingCompletionRequest, EmbeddingChatRequest))
|
||||
and self.supports_chunked_processing
|
||||
)
|
||||
|
||||
async def _process_chunked_request(
|
||||
self,
|
||||
ctx: EmbeddingServeContext,
|
||||
original_prompt: TextTokensPrompt,
|
||||
pooling_params,
|
||||
trace_headers,
|
||||
prompt_idx: int,
|
||||
) -> list[AsyncGenerator[PoolingRequestOutput, None]]:
|
||||
"""Process a single prompt using chunked processing."""
|
||||
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
|
||||
token_ids = original_prompt["prompt_token_ids"]
|
||||
|
||||
# Split into chunks using max_position_embeddings
|
||||
max_pos_embeddings = self._get_max_position_embeddings()
|
||||
# Process all chunks for MEAN aggregation
|
||||
for chunk_idx, chunk_tokens in enumerate(
|
||||
chunk_list(token_ids, max_pos_embeddings)
|
||||
):
|
||||
# Create a request ID for this chunk
|
||||
chunk_request_id = f"{ctx.request_id}-prompt-{prompt_idx}-chunk-{chunk_idx}"
|
||||
|
||||
# Create engine prompt for this chunk
|
||||
chunk_engine_prompt = EngineTokensPrompt(prompt_token_ids=chunk_tokens)
|
||||
|
||||
# Create chunk request prompt for logging
|
||||
chunk_text = ""
|
||||
chunk_request_prompt = TextTokensPrompt(
|
||||
prompt=chunk_text, prompt_token_ids=chunk_tokens
|
||||
)
|
||||
|
||||
# Log the chunk
|
||||
self._log_inputs(
|
||||
chunk_request_id,
|
||||
chunk_request_prompt,
|
||||
params=pooling_params,
|
||||
lora_request=ctx.lora_request,
|
||||
)
|
||||
|
||||
# Create generator for this chunk and wrap it to return indices
|
||||
original_generator = self.engine_client.encode(
|
||||
chunk_engine_prompt,
|
||||
pooling_params,
|
||||
chunk_request_id,
|
||||
lora_request=ctx.lora_request,
|
||||
trace_headers=trace_headers,
|
||||
priority=getattr(ctx.request, "priority", 0),
|
||||
)
|
||||
|
||||
generators.append(original_generator)
|
||||
|
||||
return generators
|
||||
|
||||
def _validate_input(
|
||||
self,
|
||||
request,
|
||||
input_ids: list[int],
|
||||
input_text: str,
|
||||
) -> TextTokensPrompt:
|
||||
"""Override to support chunked processing for embedding requests."""
|
||||
token_num = len(input_ids)
|
||||
|
||||
# Note: EmbeddingRequest doesn't have max_tokens
|
||||
if isinstance(request, (EmbeddingCompletionRequest, EmbeddingChatRequest)):
|
||||
# Check if chunked processing is enabled for pooling models
|
||||
enable_chunked = self._should_use_chunked_processing(request)
|
||||
|
||||
# Use max_position_embeddings for chunked processing decisions
|
||||
max_pos_embeddings = self._get_max_position_embeddings()
|
||||
|
||||
# Determine the effective max length for validation
|
||||
if self.max_embed_len is not None:
|
||||
# Use max_embed_len for validation instead of max_model_len
|
||||
length_type = "maximum embedding input length"
|
||||
max_length_value = self.max_embed_len
|
||||
else:
|
||||
# Fall back to max_model_len validation (original behavior)
|
||||
length_type = "maximum context length"
|
||||
max_length_value = self.max_model_len
|
||||
|
||||
validation_error_msg = (
|
||||
"This model's {length_type} is {max_length_value} tokens. "
|
||||
"However, you requested {token_num} tokens in the input for "
|
||||
"embedding generation. Please reduce the length of the input."
|
||||
)
|
||||
|
||||
chunked_processing_error_msg = (
|
||||
"This model's {length_type} is {max_length_value} tokens. "
|
||||
"However, you requested {token_num} tokens in the input for "
|
||||
"embedding generation. Please reduce the length of the input "
|
||||
"or enable chunked processing."
|
||||
)
|
||||
|
||||
# Check if input exceeds max length
|
||||
if token_num > max_length_value:
|
||||
raise ValueError(
|
||||
validation_error_msg.format(
|
||||
length_type=length_type,
|
||||
max_length_value=max_length_value,
|
||||
token_num=token_num,
|
||||
)
|
||||
)
|
||||
|
||||
# Check for chunked processing
|
||||
# when exceeding max_position_embeddings
|
||||
if token_num > max_pos_embeddings:
|
||||
if enable_chunked:
|
||||
# Allow long inputs when chunked processing is enabled
|
||||
logger.info(
|
||||
"Input length %s exceeds max_position_embeddings "
|
||||
"%s, will use chunked processing",
|
||||
token_num,
|
||||
max_pos_embeddings,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
chunked_processing_error_msg.format(
|
||||
length_type="maximum position embeddings length",
|
||||
max_length_value=max_pos_embeddings,
|
||||
token_num=token_num,
|
||||
)
|
||||
)
|
||||
|
||||
return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
|
||||
|
||||
# For other request types, use the parent's implementation
|
||||
return super()._validate_input(request, input_ids, input_text)
|
||||
|
||||
def _is_text_tokens_prompt(self, prompt) -> bool:
|
||||
"""Check if a prompt is a TextTokensPrompt (has prompt_token_ids)."""
|
||||
return (
|
||||
isinstance(prompt, dict)
|
||||
and "prompt_token_ids" in prompt
|
||||
and "prompt_embeds" not in prompt
|
||||
)
|
||||
|
||||
async def _create_single_prompt_generator(
|
||||
self,
|
||||
ctx: EmbeddingServeContext,
|
||||
engine_prompt: EngineTokensPrompt,
|
||||
pooling_params: PoolingParams,
|
||||
trace_headers: Mapping[str, str] | None,
|
||||
prompt_index: int,
|
||||
) -> AsyncGenerator[RequestOutput | PoolingRequestOutput, None]:
|
||||
"""Create a generator for a single prompt using standard processing."""
|
||||
request_id_item = f"{ctx.request_id}-{prompt_index}"
|
||||
|
||||
self._log_inputs(
|
||||
request_id_item,
|
||||
engine_prompt,
|
||||
params=pooling_params,
|
||||
lora_request=ctx.lora_request,
|
||||
)
|
||||
|
||||
# Return the original generator without wrapping
|
||||
return self.engine_client.encode(
|
||||
engine_prompt,
|
||||
pooling_params,
|
||||
request_id_item,
|
||||
lora_request=ctx.lora_request,
|
||||
trace_headers=trace_headers,
|
||||
priority=getattr(ctx.request, "priority", 0),
|
||||
)
|
||||
|
||||
@override
|
||||
async def _prepare_generators(
|
||||
self,
|
||||
ctx: ServeContext,
|
||||
) -> ErrorResponse | None:
|
||||
"""Override to support chunked processing."""
|
||||
ctx = cast(EmbeddingServeContext, ctx)
|
||||
|
||||
# Check if we should use chunked processing
|
||||
use_chunked = self._should_use_chunked_processing(ctx.request)
|
||||
|
||||
# If no chunked processing needed, delegate to parent class
|
||||
if not use_chunked:
|
||||
return await super()._prepare_generators(ctx)
|
||||
|
||||
# Custom logic for chunked processing
|
||||
generators: list[
|
||||
AsyncGenerator[RequestOutput | PoolingRequestOutput, None]
|
||||
] = []
|
||||
|
||||
try:
|
||||
trace_headers = (
|
||||
None
|
||||
if ctx.raw_request is None
|
||||
else await self._get_trace_headers(ctx.raw_request.headers)
|
||||
)
|
||||
|
||||
pooling_params = self._create_pooling_params(ctx)
|
||||
if isinstance(pooling_params, ErrorResponse):
|
||||
return pooling_params
|
||||
|
||||
# Verify and set the task for pooling params
|
||||
try:
|
||||
pooling_params.verify("embed", self.model_config)
|
||||
except ValueError as e:
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
if ctx.engine_prompts is None:
|
||||
return self.create_error_response("Engine prompts not available")
|
||||
|
||||
max_pos_embeddings = self._get_max_position_embeddings()
|
||||
|
||||
for i, engine_prompt in enumerate(ctx.engine_prompts):
|
||||
# Check if this specific prompt needs chunked processing
|
||||
if self._is_text_tokens_prompt(engine_prompt):
|
||||
# Cast to TextTokensPrompt since we've verified
|
||||
# prompt_token_ids
|
||||
text_tokens_prompt = cast(TextTokensPrompt, engine_prompt)
|
||||
if len(text_tokens_prompt["prompt_token_ids"]) > max_pos_embeddings:
|
||||
# Use chunked processing for this prompt
|
||||
chunk_generators = await self._process_chunked_request(
|
||||
ctx, text_tokens_prompt, pooling_params, trace_headers, i
|
||||
)
|
||||
generators.extend(chunk_generators)
|
||||
continue
|
||||
|
||||
# Normal processing for short prompts or non-token prompts
|
||||
generator = await self._create_single_prompt_generator(
|
||||
ctx, engine_prompt, pooling_params, trace_headers, i
|
||||
)
|
||||
generators.append(generator)
|
||||
|
||||
ctx.result_generator = merge_async_iterators(*generators)
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
@override
|
||||
async def _collect_batch(
|
||||
self,
|
||||
ctx: ServeContext,
|
||||
) -> ErrorResponse | None:
|
||||
"""Collect and aggregate batch results
|
||||
with support for chunked processing.
|
||||
|
||||
For chunked requests, performs online aggregation to
|
||||
minimize memory usage.
|
||||
For regular requests, collects results normally.
|
||||
"""
|
||||
ctx = cast(EmbeddingServeContext, ctx)
|
||||
try:
|
||||
if ctx.engine_prompts is None:
|
||||
return self.create_error_response("Engine prompts not available")
|
||||
|
||||
# Check if we used chunked processing
|
||||
use_chunked = self._should_use_chunked_processing(ctx.request)
|
||||
|
||||
if not use_chunked:
|
||||
return await super()._collect_batch(ctx=ctx)
|
||||
|
||||
if ctx.result_generator is None:
|
||||
return self.create_error_response("Result generator not available")
|
||||
|
||||
# Online aggregation for chunked requests to
|
||||
# minimize memory usage
|
||||
# Track aggregation state for each prompt
|
||||
prompt_aggregators: dict[int, dict[str, Any]] = {}
|
||||
short_prompts_results: dict[int, PoolingRequestOutput] = {}
|
||||
|
||||
async for result_idx, result in ctx.result_generator:
|
||||
if "-chunk-" in result.request_id:
|
||||
# Extract prompt_idx from chunked request_id
|
||||
parts = result.request_id.split("-")
|
||||
try:
|
||||
prompt_idx = int(parts[parts.index("prompt") + 1])
|
||||
except (ValueError, IndexError):
|
||||
# Fallback: extract from result_idx if parsing fails
|
||||
prompt_idx = result_idx
|
||||
|
||||
# Initialize aggregator for this prompt if needed
|
||||
if prompt_idx not in prompt_aggregators:
|
||||
prompt_aggregators[prompt_idx] = {
|
||||
"weighted_sum": None,
|
||||
"total_weight": 0,
|
||||
"chunk_count": 0,
|
||||
"request_id": result.request_id.split("-chunk-")[0],
|
||||
}
|
||||
|
||||
aggregator = prompt_aggregators[prompt_idx]
|
||||
|
||||
# MEAN pooling with online weighted averaging
|
||||
# Ensure result is PoolingRequestOutput
|
||||
# for embedding processing
|
||||
if not isinstance(result, PoolingRequestOutput):
|
||||
return self.create_error_response(
|
||||
f"Expected PoolingRequestOutput for "
|
||||
f"chunked embedding, got "
|
||||
f"{type(result).__name__}"
|
||||
)
|
||||
|
||||
# Handle both PoolingOutput and
|
||||
# EmbeddingOutput types
|
||||
if hasattr(result.outputs, "data"):
|
||||
# PoolingOutput case
|
||||
embedding_data = result.outputs.data
|
||||
elif hasattr(result.outputs, "embedding"):
|
||||
# EmbeddingOutput case -
|
||||
# convert embedding list to tensor
|
||||
embedding_data = result.outputs.embedding
|
||||
else:
|
||||
return self.create_error_response(
|
||||
f"Unsupported output type: {type(result.outputs).__name__}"
|
||||
)
|
||||
|
||||
if not isinstance(embedding_data, torch.Tensor):
|
||||
embedding_data = torch.tensor(
|
||||
embedding_data, dtype=torch.float32
|
||||
)
|
||||
|
||||
if result.prompt_token_ids is None:
|
||||
return self.create_error_response(
|
||||
"prompt_token_ids cannot be None for chunked processing"
|
||||
)
|
||||
weight = len(result.prompt_token_ids)
|
||||
|
||||
weighted_embedding = embedding_data.to(dtype=torch.float32) * weight
|
||||
|
||||
if aggregator["weighted_sum"] is None:
|
||||
# First chunk
|
||||
aggregator["weighted_sum"] = weighted_embedding
|
||||
else:
|
||||
# Accumulate
|
||||
aggregator["weighted_sum"] += weighted_embedding
|
||||
|
||||
aggregator["total_weight"] += weight
|
||||
aggregator["chunk_count"] += 1
|
||||
else:
|
||||
# Non-chunked result - extract prompt_idx from request_id
|
||||
parts = result.request_id.split("-")
|
||||
try:
|
||||
# Last part should be prompt index
|
||||
prompt_idx = int(parts[-1])
|
||||
except (ValueError, IndexError):
|
||||
prompt_idx = result_idx # Fallback to result_idx
|
||||
|
||||
short_prompts_results[prompt_idx] = cast(
|
||||
PoolingRequestOutput, result
|
||||
)
|
||||
|
||||
# Finalize aggregated results
|
||||
final_res_batch: list[PoolingRequestOutput | EmbeddingRequestOutput] = []
|
||||
num_prompts = len(ctx.engine_prompts)
|
||||
|
||||
for prompt_idx in range(num_prompts):
|
||||
if prompt_idx in prompt_aggregators:
|
||||
# Finalize MEAN aggregation for this chunked prompt
|
||||
aggregator = prompt_aggregators[prompt_idx]
|
||||
|
||||
weighted_sum = aggregator["weighted_sum"]
|
||||
total_weight = aggregator["total_weight"]
|
||||
|
||||
if (
|
||||
weighted_sum is not None
|
||||
and isinstance(weighted_sum, torch.Tensor)
|
||||
and isinstance(total_weight, (int, float))
|
||||
and total_weight > 0
|
||||
):
|
||||
# Compute final mean embedding
|
||||
final_embedding = weighted_sum / total_weight
|
||||
|
||||
# Create a PoolingRequestOutput
|
||||
# for the aggregated result
|
||||
pooling_output_data = PoolingOutput(data=final_embedding)
|
||||
|
||||
# Get original prompt token IDs for this prompt
|
||||
original_prompt = ctx.engine_prompts[prompt_idx]
|
||||
if not self._is_text_tokens_prompt(original_prompt):
|
||||
return self.create_error_response(
|
||||
f"Chunked prompt {prompt_idx} is not a TextTokensPrompt"
|
||||
)
|
||||
|
||||
original_token_ids = cast(TextTokensPrompt, original_prompt)[
|
||||
"prompt_token_ids"
|
||||
]
|
||||
|
||||
pooling_request_output = PoolingRequestOutput(
|
||||
request_id=aggregator["request_id"],
|
||||
prompt_token_ids=original_token_ids,
|
||||
outputs=pooling_output_data,
|
||||
num_cached_tokens=0,
|
||||
finished=True,
|
||||
)
|
||||
|
||||
final_res_batch.append(pooling_request_output)
|
||||
else:
|
||||
return self.create_error_response(
|
||||
f"Failed to aggregate chunks for prompt {prompt_idx}"
|
||||
)
|
||||
elif prompt_idx in short_prompts_results:
|
||||
final_res_batch.append(
|
||||
cast(PoolingRequestOutput, short_prompts_results[prompt_idx])
|
||||
)
|
||||
else:
|
||||
return self.create_error_response(
|
||||
f"Result not found for prompt {prompt_idx}"
|
||||
)
|
||||
|
||||
ctx.final_res_batch = cast(
|
||||
list[RequestOutput | PoolingRequestOutput], final_res_batch
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
|
||||
class OpenAIServingEmbedding(EmbeddingMixin):
|
||||
request_id_prefix = "embd"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: RequestLogger | None,
|
||||
chat_template: str | None,
|
||||
chat_template_content_format: ChatTemplateContentFormatOption,
|
||||
trust_request_chat_template: bool = False,
|
||||
log_error_stack: bool = False,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
engine_client=engine_client,
|
||||
models=models,
|
||||
request_logger=request_logger,
|
||||
log_error_stack=log_error_stack,
|
||||
)
|
||||
|
||||
self.chat_template = chat_template
|
||||
self.chat_template_content_format: Final = chat_template_content_format
|
||||
self.trust_request_chat_template = trust_request_chat_template
|
||||
|
||||
async def create_embedding(
|
||||
self,
|
||||
request: EmbeddingRequest,
|
||||
raw_request: Request | None = None,
|
||||
) -> EmbeddingResponse | ErrorResponse:
|
||||
"""
|
||||
Embedding API similar to OpenAI's API.
|
||||
|
||||
See https://platform.openai.com/docs/api-reference/embeddings/create
|
||||
for the API specification. This API mimics the OpenAI Embedding API.
|
||||
"""
|
||||
model_name = self.models.model_name()
|
||||
request_id = (
|
||||
f"{self.request_id_prefix}-"
|
||||
f"{self._base_request_id(raw_request, request.request_id)}"
|
||||
)
|
||||
|
||||
ctx = EmbeddingServeContext(
|
||||
request=request,
|
||||
raw_request=raw_request,
|
||||
model_name=model_name,
|
||||
request_id=request_id,
|
||||
chat_template=self.chat_template,
|
||||
chat_template_content_format=self.chat_template_content_format,
|
||||
)
|
||||
|
||||
return await super().handle(ctx) # type: ignore
|
||||
|
||||
@override
|
||||
def _create_pooling_params(
|
||||
self,
|
||||
ctx: ServeContext[EmbeddingRequest],
|
||||
) -> PoolingParams | ErrorResponse:
|
||||
pooling_params = super()._create_pooling_params(ctx)
|
||||
if isinstance(pooling_params, ErrorResponse):
|
||||
return pooling_params
|
||||
|
||||
try:
|
||||
pooling_params.verify("embed", self.model_config)
|
||||
except ValueError as e:
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
return pooling_params
|
||||
|
||||
async def _preprocess(
|
||||
self,
|
||||
ctx: ServeContext,
|
||||
) -> ErrorResponse | None:
|
||||
if isinstance(ctx.request, EmbeddingChatRequest):
|
||||
error_check_ret = self._validate_chat_template(
|
||||
request_chat_template=ctx.request.chat_template,
|
||||
chat_template_kwargs=ctx.request.chat_template_kwargs,
|
||||
trust_request_chat_template=self.trust_request_chat_template,
|
||||
)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
return await super()._preprocess(ctx)
|
||||
1433
vllm_old/entrypoints/openai/serving_engine.py
Normal file
1433
vllm_old/entrypoints/openai/serving_engine.py
Normal file
File diff suppressed because it is too large
Load Diff
304
vllm_old/entrypoints/openai/serving_models.py
Normal file
304
vllm_old/entrypoints/openai/serving_models.py
Normal file
@@ -0,0 +1,304 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from asyncio import Lock
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from http import HTTPStatus
|
||||
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ErrorInfo,
|
||||
ErrorResponse,
|
||||
LoadLoRAAdapterRequest,
|
||||
ModelCard,
|
||||
ModelList,
|
||||
ModelPermission,
|
||||
UnloadLoRAAdapterRequest,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry
|
||||
from vllm.utils.counter import AtomicCounter
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseModelPath:
|
||||
name: str
|
||||
model_path: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoRAModulePath:
|
||||
name: str
|
||||
path: str
|
||||
base_model_name: str | None = None
|
||||
|
||||
|
||||
class OpenAIServingModels:
|
||||
"""Shared instance to hold data about the loaded base model(s) and adapters.
|
||||
|
||||
Handles the routes:
|
||||
- /v1/models
|
||||
- /v1/load_lora_adapter
|
||||
- /v1/unload_lora_adapter
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
base_model_paths: list[BaseModelPath],
|
||||
*,
|
||||
lora_modules: list[LoRAModulePath] | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.engine_client = engine_client
|
||||
self.base_model_paths = base_model_paths
|
||||
|
||||
self.static_lora_modules = lora_modules
|
||||
self.lora_requests: dict[str, LoRARequest] = {}
|
||||
self.lora_id_counter = AtomicCounter(0)
|
||||
|
||||
self.lora_resolvers: list[LoRAResolver] = []
|
||||
for lora_resolver_name in LoRAResolverRegistry.get_supported_resolvers():
|
||||
self.lora_resolvers.append(
|
||||
LoRAResolverRegistry.get_resolver(lora_resolver_name)
|
||||
)
|
||||
self.lora_resolver_lock: dict[str, Lock] = defaultdict(Lock)
|
||||
|
||||
self.processor = self.engine_client.processor
|
||||
self.io_processor = self.engine_client.io_processor
|
||||
self.model_config = self.engine_client.model_config
|
||||
self.max_model_len = self.model_config.max_model_len
|
||||
|
||||
async def init_static_loras(self):
|
||||
"""Loads all static LoRA modules.
|
||||
Raises if any fail to load"""
|
||||
if self.static_lora_modules is None:
|
||||
return
|
||||
for lora in self.static_lora_modules:
|
||||
load_request = LoadLoRAAdapterRequest(
|
||||
lora_path=lora.path, lora_name=lora.name
|
||||
)
|
||||
load_result = await self.load_lora_adapter(
|
||||
request=load_request, base_model_name=lora.base_model_name
|
||||
)
|
||||
if isinstance(load_result, ErrorResponse):
|
||||
raise ValueError(load_result.error.message)
|
||||
|
||||
def is_base_model(self, model_name) -> bool:
|
||||
return any(model.name == model_name for model in self.base_model_paths)
|
||||
|
||||
def model_name(self, lora_request: LoRARequest | None = None) -> str:
|
||||
"""Returns the appropriate model name depending on the availability
|
||||
and support of the LoRA or base model.
|
||||
Parameters:
|
||||
- lora: LoRARequest that contain a base_model_name.
|
||||
Returns:
|
||||
- str: The name of the base model or the first available model path.
|
||||
"""
|
||||
if lora_request is not None:
|
||||
return lora_request.lora_name
|
||||
return self.base_model_paths[0].name
|
||||
|
||||
async def show_available_models(self) -> ModelList:
|
||||
"""Show available models. This includes the base model and all
|
||||
adapters"""
|
||||
model_cards = [
|
||||
ModelCard(
|
||||
id=base_model.name,
|
||||
max_model_len=self.max_model_len,
|
||||
root=base_model.model_path,
|
||||
permission=[ModelPermission()],
|
||||
)
|
||||
for base_model in self.base_model_paths
|
||||
]
|
||||
lora_cards = [
|
||||
ModelCard(
|
||||
id=lora.lora_name,
|
||||
root=lora.local_path,
|
||||
parent=lora.base_model_name
|
||||
if lora.base_model_name
|
||||
else self.base_model_paths[0].name,
|
||||
permission=[ModelPermission()],
|
||||
)
|
||||
for lora in self.lora_requests.values()
|
||||
]
|
||||
model_cards.extend(lora_cards)
|
||||
return ModelList(data=model_cards)
|
||||
|
||||
async def load_lora_adapter(
|
||||
self, request: LoadLoRAAdapterRequest, base_model_name: str | None = None
|
||||
) -> ErrorResponse | str:
|
||||
lora_name = request.lora_name
|
||||
|
||||
# Ensure atomicity based on the lora name
|
||||
async with self.lora_resolver_lock[lora_name]:
|
||||
error_check_ret = await self._check_load_lora_adapter_request(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
lora_path = request.lora_path
|
||||
unique_id = self.lora_id_counter.inc(1)
|
||||
lora_request = LoRARequest(
|
||||
lora_name=lora_name, lora_int_id=unique_id, lora_path=lora_path
|
||||
)
|
||||
if base_model_name is not None and self.is_base_model(base_model_name):
|
||||
lora_request.base_model_name = base_model_name
|
||||
|
||||
# Validate that the adapter can be loaded into the engine
|
||||
# This will also pre-load it for incoming requests
|
||||
try:
|
||||
await self.engine_client.add_lora(lora_request)
|
||||
except Exception as e:
|
||||
error_type = "BadRequestError"
|
||||
status_code = HTTPStatus.BAD_REQUEST
|
||||
if "No adapter found" in str(e):
|
||||
error_type = "NotFoundError"
|
||||
status_code = HTTPStatus.NOT_FOUND
|
||||
|
||||
return create_error_response(
|
||||
message=str(e), err_type=error_type, status_code=status_code
|
||||
)
|
||||
|
||||
self.lora_requests[lora_name] = lora_request
|
||||
logger.info(
|
||||
"Loaded new LoRA adapter: name '%s', path '%s'", lora_name, lora_path
|
||||
)
|
||||
return f"Success: LoRA adapter '{lora_name}' added successfully."
|
||||
|
||||
async def unload_lora_adapter(
|
||||
self, request: UnloadLoRAAdapterRequest
|
||||
) -> ErrorResponse | str:
|
||||
lora_name = request.lora_name
|
||||
|
||||
# Ensure atomicity based on the lora name
|
||||
async with self.lora_resolver_lock[lora_name]:
|
||||
error_check_ret = await self._check_unload_lora_adapter_request(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
# Safe to delete now since we hold the lock
|
||||
del self.lora_requests[lora_name]
|
||||
logger.info("Removed LoRA adapter: name '%s'", lora_name)
|
||||
return f"Success: LoRA adapter '{lora_name}' removed successfully."
|
||||
|
||||
async def _check_load_lora_adapter_request(
|
||||
self, request: LoadLoRAAdapterRequest
|
||||
) -> ErrorResponse | None:
|
||||
# Check if both 'lora_name' and 'lora_path' are provided
|
||||
if not request.lora_name or not request.lora_path:
|
||||
return create_error_response(
|
||||
message="Both 'lora_name' and 'lora_path' must be provided.",
|
||||
err_type="InvalidUserInput",
|
||||
status_code=HTTPStatus.BAD_REQUEST,
|
||||
)
|
||||
|
||||
# Check if the lora adapter with the given name already exists
|
||||
if request.lora_name in self.lora_requests:
|
||||
return create_error_response(
|
||||
message=f"The lora adapter '{request.lora_name}' has already been "
|
||||
"loaded.",
|
||||
err_type="InvalidUserInput",
|
||||
status_code=HTTPStatus.BAD_REQUEST,
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
async def _check_unload_lora_adapter_request(
|
||||
self, request: UnloadLoRAAdapterRequest
|
||||
) -> ErrorResponse | None:
|
||||
# Check if 'lora_name' is not provided return an error
|
||||
if not request.lora_name:
|
||||
return create_error_response(
|
||||
message="'lora_name' needs to be provided to unload a LoRA adapter.",
|
||||
err_type="InvalidUserInput",
|
||||
status_code=HTTPStatus.BAD_REQUEST,
|
||||
)
|
||||
|
||||
# Check if the lora adapter with the given name exists
|
||||
if request.lora_name not in self.lora_requests:
|
||||
return create_error_response(
|
||||
message=f"The lora adapter '{request.lora_name}' cannot be found.",
|
||||
err_type="NotFoundError",
|
||||
status_code=HTTPStatus.NOT_FOUND,
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
async def resolve_lora(self, lora_name: str) -> LoRARequest | ErrorResponse:
|
||||
"""Attempt to resolve a LoRA adapter using available resolvers.
|
||||
|
||||
Args:
|
||||
lora_name: Name/identifier of the LoRA adapter
|
||||
|
||||
Returns:
|
||||
LoRARequest if found and loaded successfully.
|
||||
ErrorResponse (404) if no resolver finds the adapter.
|
||||
ErrorResponse (400) if adapter(s) are found but none load.
|
||||
"""
|
||||
async with self.lora_resolver_lock[lora_name]:
|
||||
# First check if this LoRA is already loaded
|
||||
if lora_name in self.lora_requests:
|
||||
return self.lora_requests[lora_name]
|
||||
|
||||
base_model_name = self.model_config.model
|
||||
unique_id = self.lora_id_counter.inc(1)
|
||||
found_adapter = False
|
||||
|
||||
# Try to resolve using available resolvers
|
||||
for resolver in self.lora_resolvers:
|
||||
lora_request = await resolver.resolve_lora(base_model_name, lora_name)
|
||||
|
||||
if lora_request is not None:
|
||||
found_adapter = True
|
||||
lora_request.lora_int_id = unique_id
|
||||
|
||||
try:
|
||||
await self.engine_client.add_lora(lora_request)
|
||||
self.lora_requests[lora_name] = lora_request
|
||||
logger.info(
|
||||
"Resolved and loaded LoRA adapter '%s' using %s",
|
||||
lora_name,
|
||||
resolver.__class__.__name__,
|
||||
)
|
||||
return lora_request
|
||||
except BaseException as e:
|
||||
logger.warning(
|
||||
"Failed to load LoRA '%s' resolved by %s: %s. "
|
||||
"Trying next resolver.",
|
||||
lora_name,
|
||||
resolver.__class__.__name__,
|
||||
e,
|
||||
)
|
||||
continue
|
||||
|
||||
if found_adapter:
|
||||
# An adapter was found, but all attempts to load it failed.
|
||||
return create_error_response(
|
||||
message=(
|
||||
f"LoRA adapter '{lora_name}' was found but could not be loaded."
|
||||
),
|
||||
err_type="BadRequestError",
|
||||
status_code=HTTPStatus.BAD_REQUEST,
|
||||
)
|
||||
else:
|
||||
# No adapter was found
|
||||
return create_error_response(
|
||||
message=f"LoRA adapter {lora_name} does not exist",
|
||||
err_type="NotFoundError",
|
||||
status_code=HTTPStatus.NOT_FOUND,
|
||||
)
|
||||
|
||||
|
||||
def create_error_response(
|
||||
message: str,
|
||||
err_type: str = "BadRequestError",
|
||||
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
|
||||
) -> ErrorResponse:
|
||||
return ErrorResponse(
|
||||
error=ErrorInfo(message=message, type=err_type, code=status_code.value)
|
||||
)
|
||||
346
vllm_old/entrypoints/openai/serving_pooling.py
Normal file
346
vllm_old/entrypoints/openai/serving_pooling.py
Normal file
@@ -0,0 +1,346 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from collections.abc import AsyncGenerator, Sequence
|
||||
from typing import Final, cast
|
||||
|
||||
import jinja2
|
||||
from fastapi import Request
|
||||
from typing_extensions import assert_never
|
||||
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ErrorResponse,
|
||||
IOProcessorRequest,
|
||||
IOProcessorResponse,
|
||||
PoolingBytesResponse,
|
||||
PoolingChatRequest,
|
||||
PoolingCompletionRequest,
|
||||
PoolingRequest,
|
||||
PoolingResponse,
|
||||
PoolingResponseData,
|
||||
UsageInfo,
|
||||
)
|
||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.entrypoints.renderer import RenderConfig
|
||||
from vllm.entrypoints.utils import _validate_truncation_size
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import PoolingRequestOutput
|
||||
from vllm.tasks import PoolingTask, SupportedTask
|
||||
from vllm.utils.async_utils import merge_async_iterators
|
||||
from vllm.utils.serial_utils import (
|
||||
EmbedDType,
|
||||
EncodingFormat,
|
||||
Endianness,
|
||||
encode_pooling_bytes,
|
||||
encode_pooling_output,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class OpenAIServingPooling(OpenAIServing):
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
supported_tasks: tuple[SupportedTask, ...],
|
||||
request_logger: RequestLogger | None,
|
||||
chat_template: str | None,
|
||||
chat_template_content_format: ChatTemplateContentFormatOption,
|
||||
trust_request_chat_template: bool = False,
|
||||
log_error_stack: bool = False,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
engine_client=engine_client,
|
||||
models=models,
|
||||
request_logger=request_logger,
|
||||
log_error_stack=log_error_stack,
|
||||
)
|
||||
|
||||
self.supported_tasks = supported_tasks
|
||||
self.chat_template = chat_template
|
||||
self.chat_template_content_format: Final = chat_template_content_format
|
||||
self.trust_request_chat_template = trust_request_chat_template
|
||||
|
||||
async def create_pooling(
|
||||
self,
|
||||
request: PoolingRequest,
|
||||
raw_request: Request | None = None,
|
||||
) -> PoolingResponse | IOProcessorResponse | PoolingBytesResponse | ErrorResponse:
|
||||
"""
|
||||
See https://platform.openai.com/docs/api-reference/embeddings/create
|
||||
for the API specification. This API mimics the OpenAI Embedding API.
|
||||
"""
|
||||
error_check_ret = await self._check_model(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
model_name = self.models.model_name()
|
||||
|
||||
request_id = f"pool-{self._base_request_id(raw_request)}"
|
||||
created_time = int(time.time())
|
||||
|
||||
is_io_processor_request = isinstance(request, IOProcessorRequest)
|
||||
try:
|
||||
lora_request = self._maybe_get_adapters(request)
|
||||
|
||||
if self.model_config.skip_tokenizer_init:
|
||||
tokenizer = None
|
||||
else:
|
||||
tokenizer = await self.engine_client.get_tokenizer()
|
||||
renderer = self._get_renderer(tokenizer)
|
||||
|
||||
if getattr(request, "dimensions", None) is not None:
|
||||
return self.create_error_response(
|
||||
"dimensions is currently not supported"
|
||||
)
|
||||
|
||||
truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None)
|
||||
truncate_prompt_tokens = _validate_truncation_size(
|
||||
self.max_model_len, truncate_prompt_tokens
|
||||
)
|
||||
|
||||
if is_io_processor_request:
|
||||
if self.io_processor is None:
|
||||
raise ValueError(
|
||||
"No IOProcessor plugin installed. Please refer "
|
||||
"to the documentation and to the "
|
||||
"'prithvi_geospatial_mae_io_processor' "
|
||||
"offline inference example for more details."
|
||||
)
|
||||
|
||||
validated_prompt = self.io_processor.parse_request(request)
|
||||
|
||||
engine_prompts = await self.io_processor.pre_process_async(
|
||||
prompt=validated_prompt, request_id=request_id
|
||||
)
|
||||
if not isinstance(engine_prompts, Sequence) or isinstance(
|
||||
engine_prompts, (str, bytes, bytearray)
|
||||
):
|
||||
engine_prompts = [engine_prompts]
|
||||
|
||||
elif isinstance(request, PoolingChatRequest):
|
||||
error_check_ret = self._validate_chat_template(
|
||||
request_chat_template=request.chat_template,
|
||||
chat_template_kwargs=request.chat_template_kwargs,
|
||||
trust_request_chat_template=self.trust_request_chat_template,
|
||||
)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
(
|
||||
_,
|
||||
_,
|
||||
engine_prompts,
|
||||
) = await self._preprocess_chat(
|
||||
request,
|
||||
tokenizer,
|
||||
request.messages,
|
||||
chat_template=request.chat_template or self.chat_template,
|
||||
chat_template_content_format=self.chat_template_content_format,
|
||||
# In pooling requests, we are not generating tokens,
|
||||
# so there is no need to append extra tokens to the input
|
||||
add_generation_prompt=False,
|
||||
continue_final_message=False,
|
||||
add_special_tokens=request.add_special_tokens,
|
||||
)
|
||||
elif isinstance(request, PoolingCompletionRequest):
|
||||
engine_prompts = await renderer.render_prompt(
|
||||
prompt_or_prompts=request.input,
|
||||
config=self._build_render_config(request),
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported request of type {type(request)}")
|
||||
except (ValueError, TypeError, jinja2.TemplateError) as e:
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
# Schedule the request and get the result generator.
|
||||
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
|
||||
try:
|
||||
if is_io_processor_request:
|
||||
assert self.io_processor is not None and isinstance(
|
||||
request, IOProcessorRequest
|
||||
)
|
||||
pooling_params = self.io_processor.validate_or_generate_params()
|
||||
else:
|
||||
pooling_params = request.to_pooling_params()
|
||||
|
||||
pooling_task: PoolingTask
|
||||
if request.task is None:
|
||||
if "token_embed" in self.supported_tasks:
|
||||
pooling_task = "token_embed"
|
||||
elif "token_classify" in self.supported_tasks:
|
||||
pooling_task = "token_classify"
|
||||
elif "plugin" in self.supported_tasks:
|
||||
pooling_task = "plugin"
|
||||
else:
|
||||
return self.create_error_response(
|
||||
f"pooling_task must be one of {self.supported_tasks}."
|
||||
)
|
||||
else:
|
||||
pooling_task = request.task
|
||||
|
||||
if pooling_task not in self.supported_tasks:
|
||||
return self.create_error_response(
|
||||
f"Task {pooling_task} is not supported, it"
|
||||
f" must be one of {self.supported_tasks}."
|
||||
)
|
||||
|
||||
try:
|
||||
pooling_params.verify(pooling_task, self.model_config)
|
||||
except ValueError as e:
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
for i, engine_prompt in enumerate(engine_prompts):
|
||||
request_id_item = f"{request_id}-{i}"
|
||||
|
||||
self._log_inputs(
|
||||
request_id_item,
|
||||
engine_prompt,
|
||||
params=pooling_params,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
|
||||
trace_headers = (
|
||||
None
|
||||
if raw_request is None
|
||||
else await self._get_trace_headers(raw_request.headers)
|
||||
)
|
||||
|
||||
generator = self.engine_client.encode(
|
||||
engine_prompt,
|
||||
pooling_params,
|
||||
request_id_item,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
priority=request.priority,
|
||||
)
|
||||
|
||||
generators.append(generator)
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
result_generator = merge_async_iterators(*generators)
|
||||
|
||||
if is_io_processor_request:
|
||||
assert self.io_processor is not None
|
||||
output = await self.io_processor.post_process_async(
|
||||
model_output=result_generator,
|
||||
request_id=request_id,
|
||||
)
|
||||
return self.io_processor.output_to_response(output)
|
||||
|
||||
assert isinstance(request, (PoolingCompletionRequest, PoolingChatRequest))
|
||||
num_prompts = len(engine_prompts)
|
||||
|
||||
# Non-streaming response
|
||||
final_res_batch: list[PoolingRequestOutput | None]
|
||||
final_res_batch = [None] * num_prompts
|
||||
try:
|
||||
async for i, res in result_generator:
|
||||
final_res_batch[i] = res
|
||||
|
||||
assert all(final_res is not None for final_res in final_res_batch)
|
||||
|
||||
final_res_batch_checked = cast(list[PoolingRequestOutput], final_res_batch)
|
||||
|
||||
response = self.request_output_to_pooling_response(
|
||||
final_res_batch_checked,
|
||||
request_id,
|
||||
created_time,
|
||||
model_name,
|
||||
request.encoding_format,
|
||||
request.embed_dtype,
|
||||
request.endianness,
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
return self.create_error_response("Client disconnected")
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
return response
|
||||
|
||||
def request_output_to_pooling_response(
|
||||
self,
|
||||
final_res_batch: list[PoolingRequestOutput],
|
||||
request_id: str,
|
||||
created_time: int,
|
||||
model_name: str,
|
||||
encoding_format: EncodingFormat,
|
||||
embed_dtype: EmbedDType,
|
||||
endianness: Endianness,
|
||||
) -> PoolingResponse | PoolingBytesResponse:
|
||||
def encode_float_base64():
|
||||
items: list[PoolingResponseData] = []
|
||||
num_prompt_tokens = 0
|
||||
|
||||
for idx, final_res in enumerate(final_res_batch):
|
||||
item = PoolingResponseData(
|
||||
index=idx,
|
||||
data=encode_pooling_output(
|
||||
final_res,
|
||||
encoding_format=encoding_format,
|
||||
embed_dtype=embed_dtype,
|
||||
endianness=endianness,
|
||||
),
|
||||
)
|
||||
prompt_token_ids = final_res.prompt_token_ids
|
||||
|
||||
items.append(item)
|
||||
num_prompt_tokens += len(prompt_token_ids)
|
||||
|
||||
usage = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
total_tokens=num_prompt_tokens,
|
||||
)
|
||||
|
||||
return PoolingResponse(
|
||||
id=request_id,
|
||||
created=created_time,
|
||||
model=model_name,
|
||||
data=items,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
def encode_bytes():
|
||||
body, items, usage = encode_pooling_bytes(
|
||||
pooling_outputs=final_res_batch,
|
||||
embed_dtype=embed_dtype,
|
||||
endianness=endianness,
|
||||
)
|
||||
|
||||
metadata = {
|
||||
"id": request_id,
|
||||
"created": created_time,
|
||||
"model": model_name,
|
||||
"data": items,
|
||||
"usage": usage,
|
||||
}
|
||||
return PoolingBytesResponse(
|
||||
body=body,
|
||||
metadata=json.dumps(metadata),
|
||||
)
|
||||
|
||||
if encoding_format == "float" or encoding_format == "base64":
|
||||
return encode_float_base64()
|
||||
elif encoding_format == "bytes":
|
||||
return encode_bytes()
|
||||
else:
|
||||
assert_never(encoding_format)
|
||||
|
||||
def _build_render_config(self, request: PoolingCompletionRequest) -> RenderConfig:
|
||||
return RenderConfig(
|
||||
max_length=self.max_model_len,
|
||||
truncate_prompt_tokens=request.truncate_prompt_tokens,
|
||||
add_special_tokens=request.add_special_tokens,
|
||||
)
|
||||
2021
vllm_old/entrypoints/openai/serving_responses.py
Normal file
2021
vllm_old/entrypoints/openai/serving_responses.py
Normal file
File diff suppressed because it is too large
Load Diff
503
vllm_old/entrypoints/openai/serving_score.py
Normal file
503
vllm_old/entrypoints/openai/serving_score.py
Normal file
@@ -0,0 +1,503 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import asyncio
|
||||
import time
|
||||
from collections.abc import AsyncGenerator, Mapping
|
||||
from typing import Any
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ErrorResponse,
|
||||
RerankDocument,
|
||||
RerankRequest,
|
||||
RerankResponse,
|
||||
RerankResult,
|
||||
RerankUsage,
|
||||
ScoreRequest,
|
||||
ScoreResponse,
|
||||
ScoreResponseData,
|
||||
UsageInfo,
|
||||
)
|
||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.entrypoints.score_utils import (
|
||||
ScoreContentPartParam,
|
||||
ScoreMultiModalParam,
|
||||
_cosine_similarity,
|
||||
_validate_score_input_lens,
|
||||
compress_token_type_ids,
|
||||
get_score_prompt,
|
||||
)
|
||||
from vllm.entrypoints.utils import _validate_truncation_size
|
||||
from vllm.inputs.data import TokensPrompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||
from vllm.utils.async_utils import make_async, merge_async_iterators
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class ServingScores(OpenAIServing):
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: RequestLogger | None,
|
||||
log_error_stack: bool = False,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
engine_client=engine_client,
|
||||
models=models,
|
||||
request_logger=request_logger,
|
||||
log_error_stack=log_error_stack,
|
||||
)
|
||||
|
||||
async def _embedding_score(
|
||||
self,
|
||||
tokenizer: AnyTokenizer,
|
||||
texts_1: list[str],
|
||||
texts_2: list[str],
|
||||
request: RerankRequest | ScoreRequest,
|
||||
request_id: str,
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
lora_request: LoRARequest | None | None = None,
|
||||
trace_headers: Mapping[str, str] | None = None,
|
||||
) -> list[PoolingRequestOutput] | ErrorResponse:
|
||||
input_texts = texts_1 + texts_2
|
||||
|
||||
engine_prompts: list[TokensPrompt] = []
|
||||
tokenize_async = make_async(
|
||||
tokenizer.__call__, executor=self._tokenizer_executor
|
||||
)
|
||||
|
||||
tokenization_kwargs = tokenization_kwargs or {}
|
||||
tokenized_prompts = await asyncio.gather(
|
||||
*(tokenize_async(t, **tokenization_kwargs) for t in input_texts)
|
||||
)
|
||||
|
||||
for tok_result, input_text in zip(tokenized_prompts, input_texts):
|
||||
text_token_prompt = self._validate_input(
|
||||
request, tok_result["input_ids"], input_text
|
||||
)
|
||||
|
||||
engine_prompts.append(
|
||||
TokensPrompt(prompt_token_ids=text_token_prompt["prompt_token_ids"])
|
||||
)
|
||||
|
||||
# Schedule the request and get the result generator.
|
||||
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
|
||||
pooling_params = request.to_pooling_params()
|
||||
|
||||
try:
|
||||
pooling_params.verify("embed", self.model_config)
|
||||
except ValueError as e:
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
for i, engine_prompt in enumerate(engine_prompts):
|
||||
request_id_item = f"{request_id}-{i}"
|
||||
|
||||
self._log_inputs(
|
||||
request_id_item,
|
||||
input_texts[i],
|
||||
params=pooling_params,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
|
||||
generators.append(
|
||||
self.engine_client.encode(
|
||||
engine_prompt,
|
||||
pooling_params,
|
||||
request_id_item,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
priority=request.priority,
|
||||
)
|
||||
)
|
||||
|
||||
result_generator = merge_async_iterators(*generators)
|
||||
|
||||
# Non-streaming response
|
||||
final_res_batch: list[PoolingRequestOutput] = []
|
||||
|
||||
embeddings: list[PoolingRequestOutput | None] = [None] * len(engine_prompts)
|
||||
|
||||
async for i, res in result_generator:
|
||||
embeddings[i] = res
|
||||
|
||||
emb_texts_1: list[PoolingRequestOutput] = []
|
||||
emb_texts_2: list[PoolingRequestOutput] = []
|
||||
|
||||
for i in range(0, len(texts_1)):
|
||||
assert (emb := embeddings[i]) is not None
|
||||
emb_texts_1.append(emb)
|
||||
|
||||
for i in range(len(texts_1), len(embeddings)):
|
||||
assert (emb := embeddings[i]) is not None
|
||||
emb_texts_2.append(emb)
|
||||
|
||||
if len(emb_texts_1) == 1:
|
||||
emb_texts_1 = emb_texts_1 * len(emb_texts_2)
|
||||
|
||||
final_res_batch = _cosine_similarity(
|
||||
tokenizer=tokenizer, embed_1=emb_texts_1, embed_2=emb_texts_2
|
||||
)
|
||||
|
||||
return final_res_batch
|
||||
|
||||
def _preprocess_score(
|
||||
self,
|
||||
request: RerankRequest | ScoreRequest,
|
||||
tokenizer: AnyTokenizer,
|
||||
tokenization_kwargs: dict[str, Any],
|
||||
data_1: str | ScoreContentPartParam,
|
||||
data_2: str | ScoreContentPartParam,
|
||||
) -> tuple[str, TokensPrompt]:
|
||||
model_config = self.model_config
|
||||
|
||||
full_prompt, engine_prompt = get_score_prompt(
|
||||
model_config=model_config,
|
||||
data_1=data_1,
|
||||
data_2=data_2,
|
||||
tokenizer=tokenizer,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
)
|
||||
self._validate_input(request, engine_prompt["prompt_token_ids"], full_prompt)
|
||||
if request.mm_processor_kwargs is not None:
|
||||
engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs
|
||||
|
||||
return full_prompt, engine_prompt
|
||||
|
||||
async def _cross_encoding_score(
|
||||
self,
|
||||
tokenizer: AnyTokenizer,
|
||||
data_1: list[str] | list[ScoreContentPartParam],
|
||||
data_2: list[str] | list[ScoreContentPartParam],
|
||||
request: RerankRequest | ScoreRequest,
|
||||
request_id: str,
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
lora_request: LoRARequest | None | None = None,
|
||||
trace_headers: Mapping[str, str] | None = None,
|
||||
) -> list[PoolingRequestOutput] | ErrorResponse:
|
||||
request_prompts: list[str] = []
|
||||
engine_prompts: list[TokensPrompt] = []
|
||||
|
||||
if len(data_1) == 1:
|
||||
data_1 = data_1 * len(data_2)
|
||||
|
||||
if isinstance(tokenizer, MistralTokenizer):
|
||||
raise ValueError("MistralTokenizer not supported for cross-encoding")
|
||||
|
||||
tokenization_kwargs = tokenization_kwargs or {}
|
||||
|
||||
input_pairs = [(t1, t2) for t1, t2 in zip(data_1, data_2)]
|
||||
|
||||
preprocess_async = make_async(
|
||||
self._preprocess_score, executor=self._tokenizer_executor
|
||||
)
|
||||
|
||||
preprocessed_prompts = await asyncio.gather(
|
||||
*(
|
||||
preprocess_async(
|
||||
request=request,
|
||||
tokenizer=tokenizer,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
data_1=t1,
|
||||
data_2=t2,
|
||||
)
|
||||
for t1, t2 in input_pairs
|
||||
)
|
||||
)
|
||||
|
||||
for full_prompt, engine_prompt in preprocessed_prompts:
|
||||
request_prompts.append(full_prompt)
|
||||
engine_prompts.append(engine_prompt)
|
||||
|
||||
# Schedule the request and get the result generator.
|
||||
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
|
||||
|
||||
default_pooling_params = request.to_pooling_params()
|
||||
|
||||
try:
|
||||
default_pooling_params.verify("score", self.model_config)
|
||||
except ValueError as e:
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
for i, engine_prompt in enumerate(engine_prompts):
|
||||
request_id_item = f"{request_id}-{i}"
|
||||
|
||||
self._log_inputs(
|
||||
request_id_item,
|
||||
request_prompts[i],
|
||||
params=default_pooling_params,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
|
||||
if token_type_ids := engine_prompt.pop("token_type_ids", None):
|
||||
pooling_params = default_pooling_params.clone()
|
||||
compressed = compress_token_type_ids(token_type_ids)
|
||||
pooling_params.extra_kwargs = {"compressed_token_type_ids": compressed}
|
||||
else:
|
||||
pooling_params = default_pooling_params
|
||||
|
||||
generator = self.engine_client.encode(
|
||||
engine_prompt,
|
||||
pooling_params,
|
||||
request_id_item,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
priority=request.priority,
|
||||
)
|
||||
|
||||
generators.append(generator)
|
||||
|
||||
result_generator = merge_async_iterators(*generators)
|
||||
|
||||
# Non-streaming response
|
||||
final_res_batch: list[PoolingRequestOutput | None] = [None] * len(
|
||||
engine_prompts
|
||||
)
|
||||
|
||||
async for i, res in result_generator:
|
||||
final_res_batch[i] = res
|
||||
|
||||
return [out for out in final_res_batch if out is not None]
|
||||
|
||||
async def _run_scoring(
|
||||
self,
|
||||
data_1: list[str] | str | ScoreMultiModalParam,
|
||||
data_2: list[str] | str | ScoreMultiModalParam,
|
||||
request: ScoreRequest | RerankRequest,
|
||||
request_id: str,
|
||||
raw_request: Request | None = None,
|
||||
) -> list[PoolingRequestOutput] | ErrorResponse:
|
||||
lora_request = self._maybe_get_adapters(request)
|
||||
|
||||
tokenizer = await self.engine_client.get_tokenizer()
|
||||
|
||||
truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None)
|
||||
|
||||
tokenization_kwargs: dict[str, Any] = {}
|
||||
_validate_truncation_size(
|
||||
self.max_model_len, truncate_prompt_tokens, tokenization_kwargs
|
||||
)
|
||||
|
||||
trace_headers = (
|
||||
None
|
||||
if raw_request is None
|
||||
else await self._get_trace_headers(raw_request.headers)
|
||||
)
|
||||
|
||||
if not self.model_config.is_multimodal_model and (
|
||||
isinstance(data_1, dict) or isinstance(data_2, dict)
|
||||
):
|
||||
raise ValueError(
|
||||
f"MultiModalParam is not supported for {self.model_config.architecture}" # noqa: E501
|
||||
)
|
||||
|
||||
if isinstance(data_1, str):
|
||||
data_1 = [data_1]
|
||||
elif isinstance(data_1, dict):
|
||||
data_1 = data_1.get("content") # type: ignore[assignment]
|
||||
|
||||
if isinstance(data_2, str):
|
||||
data_2 = [data_2]
|
||||
elif isinstance(data_2, dict):
|
||||
data_2 = data_2.get("content") # type: ignore[assignment]
|
||||
|
||||
_validate_score_input_lens(data_1, data_2) # type: ignore[arg-type]
|
||||
|
||||
if self.model_config.is_cross_encoder:
|
||||
return await self._cross_encoding_score(
|
||||
tokenizer=tokenizer,
|
||||
data_1=data_1, # type: ignore[arg-type]
|
||||
data_2=data_2, # type: ignore[arg-type]
|
||||
request=request,
|
||||
request_id=request_id,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
)
|
||||
|
||||
else:
|
||||
return await self._embedding_score(
|
||||
tokenizer=tokenizer,
|
||||
texts_1=data_1, # type: ignore[arg-type]
|
||||
texts_2=data_2, # type: ignore[arg-type]
|
||||
request=request,
|
||||
request_id=request_id,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
)
|
||||
|
||||
async def create_score(
|
||||
self,
|
||||
request: ScoreRequest,
|
||||
raw_request: Request | None = None,
|
||||
) -> ScoreResponse | ErrorResponse:
|
||||
"""
|
||||
Score API similar to Sentence Transformers cross encoder
|
||||
|
||||
See https://sbert.net/docs/package_reference/cross_encoder
|
||||
"""
|
||||
error_check_ret = await self._check_model(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
request_id = f"score-{self._base_request_id(raw_request)}"
|
||||
created_time = int(time.time())
|
||||
|
||||
try:
|
||||
final_res_batch = await self._run_scoring(
|
||||
request.text_1,
|
||||
request.text_2,
|
||||
request,
|
||||
request_id,
|
||||
raw_request,
|
||||
)
|
||||
if isinstance(final_res_batch, ErrorResponse):
|
||||
return final_res_batch
|
||||
|
||||
return self.request_output_to_score_response(
|
||||
final_res_batch,
|
||||
request_id,
|
||||
created_time,
|
||||
self.models.model_name(),
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
return self.create_error_response("Client disconnected")
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
async def do_rerank(
|
||||
self, request: RerankRequest, raw_request: Request | None = None
|
||||
) -> RerankResponse | ErrorResponse:
|
||||
"""
|
||||
Rerank API based on JinaAI's rerank API; implements the same
|
||||
API interface. Designed for compatibility with off-the-shelf
|
||||
tooling, since this is a common standard for reranking APIs
|
||||
|
||||
See example client implementations at
|
||||
https://github.com/infiniflow/ragflow/blob/main/rag/llm/rerank_model.py
|
||||
numerous clients use this standard.
|
||||
"""
|
||||
error_check_ret = await self._check_model(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
request_id = f"rerank-{self._base_request_id(raw_request)}"
|
||||
documents = request.documents
|
||||
top_n = (
|
||||
request.top_n
|
||||
if request.top_n > 0
|
||||
else (
|
||||
len(documents)
|
||||
if isinstance(documents, list)
|
||||
else len(documents["content"])
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
final_res_batch = await self._run_scoring(
|
||||
request.query,
|
||||
documents,
|
||||
request,
|
||||
request_id,
|
||||
raw_request,
|
||||
)
|
||||
if isinstance(final_res_batch, ErrorResponse):
|
||||
return final_res_batch
|
||||
|
||||
return self.request_output_to_rerank_response(
|
||||
final_res_batch,
|
||||
request_id,
|
||||
self.models.model_name(),
|
||||
documents,
|
||||
top_n,
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
return self.create_error_response("Client disconnected")
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
def request_output_to_score_response(
|
||||
self,
|
||||
final_res_batch: list[PoolingRequestOutput],
|
||||
request_id: str,
|
||||
created_time: int,
|
||||
model_name: str,
|
||||
) -> ScoreResponse:
|
||||
items: list[ScoreResponseData] = []
|
||||
num_prompt_tokens = 0
|
||||
|
||||
for idx, final_res in enumerate(final_res_batch):
|
||||
classify_res = ScoringRequestOutput.from_base(final_res)
|
||||
|
||||
item = ScoreResponseData(
|
||||
index=idx,
|
||||
score=classify_res.outputs.score,
|
||||
)
|
||||
prompt_token_ids = final_res.prompt_token_ids
|
||||
|
||||
items.append(item)
|
||||
num_prompt_tokens += len(prompt_token_ids)
|
||||
|
||||
usage = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
total_tokens=num_prompt_tokens,
|
||||
)
|
||||
|
||||
return ScoreResponse(
|
||||
id=request_id,
|
||||
created=created_time,
|
||||
model=model_name,
|
||||
data=items,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
def request_output_to_rerank_response(
|
||||
self,
|
||||
final_res_batch: list[PoolingRequestOutput],
|
||||
request_id: str,
|
||||
model_name: str,
|
||||
documents: list[str] | ScoreMultiModalParam,
|
||||
top_n: int,
|
||||
) -> RerankResponse:
|
||||
"""
|
||||
Convert the output of do_rank to a RerankResponse
|
||||
"""
|
||||
results: list[RerankResult] = []
|
||||
num_prompt_tokens = 0
|
||||
for idx, final_res in enumerate(final_res_batch):
|
||||
classify_res = ScoringRequestOutput.from_base(final_res)
|
||||
|
||||
result = RerankResult(
|
||||
index=idx,
|
||||
document=RerankDocument(text=documents[idx])
|
||||
if isinstance(documents, list)
|
||||
else RerankDocument(multi_modal=documents["content"][idx]),
|
||||
relevance_score=classify_res.outputs.score,
|
||||
)
|
||||
results.append(result)
|
||||
prompt_token_ids = final_res.prompt_token_ids
|
||||
num_prompt_tokens += len(prompt_token_ids)
|
||||
|
||||
# sort by relevance, then return the top n if set
|
||||
results.sort(key=lambda x: x.relevance_score, reverse=True)
|
||||
if top_n < len(documents):
|
||||
results = results[:top_n]
|
||||
|
||||
return RerankResponse(
|
||||
id=request_id,
|
||||
model=model_name,
|
||||
results=results,
|
||||
usage=RerankUsage(total_tokens=num_prompt_tokens),
|
||||
)
|
||||
203
vllm_old/entrypoints/openai/serving_tokenization.py
Normal file
203
vllm_old/entrypoints/openai/serving_tokenization.py
Normal file
@@ -0,0 +1,203 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Final
|
||||
|
||||
import jinja2
|
||||
from fastapi import Request
|
||||
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
DetokenizeRequest,
|
||||
DetokenizeResponse,
|
||||
ErrorResponse,
|
||||
TokenizeChatRequest,
|
||||
TokenizeRequest,
|
||||
TokenizeResponse,
|
||||
TokenizerInfoResponse,
|
||||
)
|
||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.entrypoints.renderer import RenderConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class OpenAIServingTokenization(OpenAIServing):
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: RequestLogger | None,
|
||||
chat_template: str | None,
|
||||
chat_template_content_format: ChatTemplateContentFormatOption,
|
||||
trust_request_chat_template: bool = False,
|
||||
log_error_stack: bool = False,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
engine_client=engine_client,
|
||||
models=models,
|
||||
request_logger=request_logger,
|
||||
log_error_stack=log_error_stack,
|
||||
)
|
||||
|
||||
self.chat_template = chat_template
|
||||
self.chat_template_content_format: Final = chat_template_content_format
|
||||
self.trust_request_chat_template = trust_request_chat_template
|
||||
|
||||
async def create_tokenize(
|
||||
self,
|
||||
request: TokenizeRequest,
|
||||
raw_request: Request,
|
||||
) -> TokenizeResponse | ErrorResponse:
|
||||
error_check_ret = await self._check_model(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
request_id = f"tokn-{self._base_request_id(raw_request)}"
|
||||
|
||||
try:
|
||||
lora_request = self._maybe_get_adapters(request)
|
||||
|
||||
tokenizer = await self.engine_client.get_tokenizer()
|
||||
renderer = self._get_renderer(tokenizer)
|
||||
|
||||
if isinstance(request, TokenizeChatRequest):
|
||||
tool_dicts = (
|
||||
None
|
||||
if request.tools is None
|
||||
else [tool.model_dump() for tool in request.tools]
|
||||
)
|
||||
error_check_ret = self._validate_chat_template(
|
||||
request_chat_template=request.chat_template,
|
||||
chat_template_kwargs=request.chat_template_kwargs,
|
||||
trust_request_chat_template=self.trust_request_chat_template,
|
||||
)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
(
|
||||
_,
|
||||
_,
|
||||
engine_prompts,
|
||||
) = await self._preprocess_chat(
|
||||
request,
|
||||
tokenizer,
|
||||
request.messages,
|
||||
tool_dicts=tool_dicts,
|
||||
chat_template=request.chat_template or self.chat_template,
|
||||
chat_template_content_format=self.chat_template_content_format,
|
||||
add_generation_prompt=request.add_generation_prompt,
|
||||
continue_final_message=request.continue_final_message,
|
||||
chat_template_kwargs=request.chat_template_kwargs,
|
||||
add_special_tokens=request.add_special_tokens,
|
||||
)
|
||||
else:
|
||||
engine_prompts = await renderer.render_prompt(
|
||||
prompt_or_prompts=request.prompt,
|
||||
config=self._build_render_config(request),
|
||||
)
|
||||
except (ValueError, TypeError, jinja2.TemplateError) as e:
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
return self.create_error_response(f"{e} {e.__cause__}")
|
||||
|
||||
input_ids: list[int] = []
|
||||
for engine_prompt in engine_prompts:
|
||||
self._log_inputs(
|
||||
request_id, engine_prompt, params=None, lora_request=lora_request
|
||||
)
|
||||
|
||||
if isinstance(engine_prompt, dict) and "prompt_token_ids" in engine_prompt:
|
||||
input_ids.extend(engine_prompt["prompt_token_ids"])
|
||||
|
||||
token_strs = None
|
||||
if request.return_token_strs:
|
||||
token_strs = tokenizer.convert_ids_to_tokens(input_ids)
|
||||
|
||||
return TokenizeResponse(
|
||||
tokens=input_ids,
|
||||
token_strs=token_strs,
|
||||
count=len(input_ids),
|
||||
max_model_len=self.max_model_len,
|
||||
)
|
||||
|
||||
async def create_detokenize(
|
||||
self,
|
||||
request: DetokenizeRequest,
|
||||
raw_request: Request,
|
||||
) -> DetokenizeResponse | ErrorResponse:
|
||||
error_check_ret = await self._check_model(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
request_id = f"tokn-{self._base_request_id(raw_request)}"
|
||||
|
||||
lora_request = self._maybe_get_adapters(request)
|
||||
|
||||
tokenizer = await self.engine_client.get_tokenizer()
|
||||
|
||||
self._log_inputs(
|
||||
request_id, request.tokens, params=None, lora_request=lora_request
|
||||
)
|
||||
|
||||
prompt_input = await self._tokenize_prompt_input_async(
|
||||
request,
|
||||
tokenizer,
|
||||
request.tokens,
|
||||
)
|
||||
input_text = prompt_input["prompt"]
|
||||
|
||||
return DetokenizeResponse(prompt=input_text)
|
||||
|
||||
async def get_tokenizer_info(
|
||||
self,
|
||||
) -> TokenizerInfoResponse | ErrorResponse:
|
||||
"""Get comprehensive tokenizer information."""
|
||||
try:
|
||||
tokenizer = await self.engine_client.get_tokenizer()
|
||||
info = TokenizerInfo(tokenizer, self.chat_template).to_dict()
|
||||
return TokenizerInfoResponse(**info)
|
||||
except Exception as e:
|
||||
return self.create_error_response(f"Failed to get tokenizer info: {str(e)}")
|
||||
|
||||
def _build_render_config(self, request: TokenizeRequest) -> RenderConfig:
|
||||
return RenderConfig(add_special_tokens=request.add_special_tokens)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TokenizerInfo:
|
||||
tokenizer: AnyTokenizer
|
||||
chat_template: str | None
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Return the tokenizer configuration."""
|
||||
return self._get_tokenizer_config()
|
||||
|
||||
def _get_tokenizer_config(self) -> dict[str, Any]:
|
||||
"""Get tokenizer configuration directly from the tokenizer object."""
|
||||
config = dict(getattr(self.tokenizer, "init_kwargs", None) or {})
|
||||
|
||||
# Remove file path fields
|
||||
config.pop("vocab_file", None)
|
||||
config.pop("merges_file", None)
|
||||
|
||||
config = self._make_json_serializable(config)
|
||||
config["tokenizer_class"] = type(self.tokenizer).__name__
|
||||
if self.chat_template:
|
||||
config["chat_template"] = self.chat_template
|
||||
return config
|
||||
|
||||
def _make_json_serializable(self, obj):
|
||||
"""Convert any non-JSON-serializable objects to serializable format."""
|
||||
if hasattr(obj, "content"):
|
||||
return obj.content
|
||||
elif isinstance(obj, dict):
|
||||
return {k: self._make_json_serializable(v) for k, v in obj.items()}
|
||||
elif isinstance(obj, list):
|
||||
return [self._make_json_serializable(item) for item in obj]
|
||||
else:
|
||||
return obj
|
||||
269
vllm_old/entrypoints/openai/serving_tokens.py
Normal file
269
vllm_old/entrypoints/openai/serving_tokens.py
Normal file
@@ -0,0 +1,269 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import asyncio
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from collections.abc import Sequence as GenericSequence
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
# yapf: disable
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionLogProb,
|
||||
ChatCompletionLogProbs,
|
||||
ChatCompletionLogProbsContent,
|
||||
ErrorResponse,
|
||||
GenerateRequest,
|
||||
GenerateResponse,
|
||||
GenerateResponseChoice,
|
||||
PromptTokenUsageInfo,
|
||||
RequestResponseMetadata,
|
||||
UsageInfo,
|
||||
)
|
||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing, clamp_prompt_logprobs
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logprobs import Logprob
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils.collection_utils import as_list
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class ServingTokens(OpenAIServing):
|
||||
"""Provides Tokens IN <> Tokens OUT functionality to vLLM API."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: RequestLogger | None,
|
||||
force_no_detokenize: bool = False,
|
||||
return_tokens_as_token_ids: bool = False,
|
||||
log_error_stack: bool = False,
|
||||
enable_prompt_tokens_details: bool = False,
|
||||
enable_log_outputs: bool = False,
|
||||
):
|
||||
super().__init__(engine_client=engine_client,
|
||||
models=models,
|
||||
request_logger=request_logger,
|
||||
return_tokens_as_token_ids=return_tokens_as_token_ids,
|
||||
log_error_stack=log_error_stack)
|
||||
self.enable_prompt_tokens_details = enable_prompt_tokens_details
|
||||
self.enable_log_outputs = enable_log_outputs
|
||||
self.force_no_detokenize = force_no_detokenize
|
||||
if force_no_detokenize:
|
||||
logger.info("Tokens-only mode is enabled, skipping detokenization "
|
||||
"step for incoming requests.")
|
||||
|
||||
async def serve_tokens(
|
||||
self,
|
||||
request: GenerateRequest,
|
||||
raw_request: Request | None = None
|
||||
) -> GenerateResponse | ErrorResponse:
|
||||
error_check_ret = await self._check_model(request)
|
||||
if error_check_ret is not None:
|
||||
logger.error("Error with model %s", error_check_ret)
|
||||
return error_check_ret
|
||||
|
||||
# If the engine is dead, raise the engine's DEAD_ERROR.
|
||||
# This is required for the streaming case, where we return a
|
||||
# success status before we actually start generating text :).
|
||||
if self.engine_client.errored:
|
||||
raise self.engine_client.dead_error
|
||||
|
||||
lora_request = None
|
||||
lora_request = self._maybe_get_adapters(request,
|
||||
supports_default_mm_loras=True)
|
||||
|
||||
model_name = self.models.model_name(lora_request)
|
||||
|
||||
request_id = "generate-tokens-" \
|
||||
f"{self._base_request_id(raw_request, request.request_id)}"
|
||||
|
||||
request_metadata = RequestResponseMetadata(request_id=request_id)
|
||||
if raw_request:
|
||||
raw_request.state.request_metadata = request_metadata
|
||||
|
||||
# TODO(NickLucche): Change to EngineCoreRequest once Renderer work is
|
||||
# completed
|
||||
engine_prompt = EngineTokensPrompt(prompt_token_ids=request.token_ids)
|
||||
if request.features is not None:
|
||||
engine_prompt["multi_modal_data"] = None
|
||||
|
||||
if hasattr(request, "cache_salt") and request.cache_salt is not None:
|
||||
engine_prompt["cache_salt"] = request.cache_salt
|
||||
|
||||
# Schedule the request and get the result generator.
|
||||
result_generator: AsyncGenerator[RequestOutput, None] | None = None
|
||||
try:
|
||||
sampling_params = request.sampling_params
|
||||
if self.force_no_detokenize:
|
||||
sampling_params.detokenize = False
|
||||
|
||||
self._log_inputs(request_id,
|
||||
request.token_ids,
|
||||
params=sampling_params,
|
||||
lora_request=lora_request)
|
||||
|
||||
trace_headers = (None if raw_request is None else await
|
||||
self._get_trace_headers(raw_request.headers))
|
||||
|
||||
result_generator = self.engine_client.generate(
|
||||
engine_prompt,
|
||||
sampling_params,
|
||||
request_id,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
priority=request.priority,
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
# TODO(NickLucche): Implement streaming response
|
||||
|
||||
try:
|
||||
assert result_generator is not None
|
||||
return await self.serve_tokens_full_generator(
|
||||
request, result_generator, request_id, model_name,
|
||||
request_metadata)
|
||||
except ValueError as e:
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
async def serve_tokens_full_generator(
|
||||
self,
|
||||
request: GenerateRequest,
|
||||
result_generator: AsyncGenerator[RequestOutput, None],
|
||||
request_id: str,
|
||||
model_name: str,
|
||||
request_metadata: RequestResponseMetadata,
|
||||
) -> ErrorResponse | GenerateResponse:
|
||||
|
||||
created_time = int(time.time())
|
||||
final_res: RequestOutput | None = None
|
||||
sampling_params: SamplingParams = request.sampling_params
|
||||
|
||||
try:
|
||||
async for res in result_generator:
|
||||
final_res = res
|
||||
except asyncio.CancelledError:
|
||||
return self.create_error_response("Client disconnected")
|
||||
except ValueError as e:
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
assert final_res is not None
|
||||
|
||||
choices: list[GenerateResponseChoice] = []
|
||||
num_generated_tokens = 0
|
||||
for output in final_res.outputs:
|
||||
token_ids = output.token_ids
|
||||
out_logprobs = output.logprobs
|
||||
|
||||
# This is top_logprobs in completions API
|
||||
if sampling_params.logprobs:
|
||||
assert out_logprobs is not None, "Did not output logprobs"
|
||||
logprobs = self._create_tokens_logprobs(
|
||||
token_ids=token_ids,
|
||||
top_logprobs=out_logprobs,
|
||||
num_output_top_logprobs=sampling_params.logprobs,
|
||||
)
|
||||
else:
|
||||
logprobs = None
|
||||
|
||||
choice_data = GenerateResponseChoice(
|
||||
index=output.index,
|
||||
logprobs=logprobs,
|
||||
finish_reason=output.finish_reason
|
||||
if output.finish_reason else "stop",
|
||||
token_ids=as_list(output.token_ids))
|
||||
|
||||
choices.append(choice_data)
|
||||
num_generated_tokens += len(output.token_ids)
|
||||
|
||||
assert final_res.prompt_token_ids is not None
|
||||
num_prompt_tokens = len(final_res.prompt_token_ids)
|
||||
if final_res.encoder_prompt_token_ids is not None:
|
||||
num_prompt_tokens += len(final_res.encoder_prompt_token_ids)
|
||||
|
||||
usage = UsageInfo(prompt_tokens=num_prompt_tokens,
|
||||
completion_tokens=num_generated_tokens,
|
||||
total_tokens=num_prompt_tokens +
|
||||
num_generated_tokens)
|
||||
if self.enable_prompt_tokens_details and final_res.num_cached_tokens:
|
||||
# This info is not available at the /coordinator level
|
||||
usage.prompt_tokens_details = PromptTokenUsageInfo(
|
||||
cached_tokens=final_res.num_cached_tokens)
|
||||
|
||||
request_metadata.final_usage_info = usage
|
||||
|
||||
response = GenerateResponse(
|
||||
id=request_id,
|
||||
created=created_time,
|
||||
model=model_name,
|
||||
choices=choices,
|
||||
usage=usage,
|
||||
prompt_logprobs=clamp_prompt_logprobs(final_res.prompt_logprobs),
|
||||
kv_transfer_params=final_res.kv_transfer_params,
|
||||
)
|
||||
|
||||
# Log complete response if output logging is enabled
|
||||
if self.enable_log_outputs and self.request_logger:
|
||||
for choice in choices:
|
||||
# Get the corresponding output token IDs
|
||||
output_token_ids = None
|
||||
if choice.index < len(final_res.outputs):
|
||||
output_token_ids = final_res.outputs[
|
||||
choice.index].token_ids
|
||||
|
||||
if output_token_ids:
|
||||
# Log token_ids only.
|
||||
self.request_logger.log_outputs(
|
||||
request_id=request_id,
|
||||
outputs="",
|
||||
output_token_ids=output_token_ids,
|
||||
finish_reason=choice.finish_reason,
|
||||
is_streaming=False,
|
||||
delta=False,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
def _create_tokens_logprobs(
|
||||
self,
|
||||
token_ids: GenericSequence[int],
|
||||
top_logprobs: GenericSequence[dict[int, Logprob] | None],
|
||||
num_output_top_logprobs: int | None = None,
|
||||
) -> ChatCompletionLogProbs:
|
||||
"""Create OpenAI-style logprobs."""
|
||||
logprobs_content: list[ChatCompletionLogProbsContent] = []
|
||||
|
||||
for i, token_id in enumerate(token_ids):
|
||||
token = f"token_id:{token_id}"
|
||||
step_top_logprobs = top_logprobs[i]
|
||||
if step_top_logprobs is None or step_top_logprobs.get(
|
||||
token_id) is None:
|
||||
logprobs_content.append(
|
||||
ChatCompletionLogProbsContent(token=token, ))
|
||||
else:
|
||||
step_token = step_top_logprobs[token_id]
|
||||
|
||||
logprobs_content.append(
|
||||
ChatCompletionLogProbsContent(
|
||||
token=token,
|
||||
logprob=max(step_token.logprob, -9999.0),
|
||||
top_logprobs=[
|
||||
ChatCompletionLogProb(
|
||||
token=token,
|
||||
logprob=max(p[1].logprob, -9999.0),
|
||||
) for i, p in enumerate(step_top_logprobs.items())
|
||||
if num_output_top_logprobs
|
||||
and i < num_output_top_logprobs
|
||||
]))
|
||||
|
||||
return ChatCompletionLogProbs(content=logprobs_content)
|
||||
148
vllm_old/entrypoints/openai/serving_transcription.py
Normal file
148
vllm_old/entrypoints/openai/serving_transcription.py
Normal file
@@ -0,0 +1,148 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ErrorResponse,
|
||||
RequestResponseMetadata,
|
||||
TranscriptionRequest,
|
||||
TranscriptionResponse,
|
||||
TranscriptionResponseStreamChoice,
|
||||
TranscriptionStreamResponse,
|
||||
TranslationRequest,
|
||||
TranslationResponse,
|
||||
TranslationResponseStreamChoice,
|
||||
TranslationStreamResponse,
|
||||
)
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.entrypoints.openai.speech_to_text import OpenAISpeechToText
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import RequestOutput
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class OpenAIServingTranscription(OpenAISpeechToText):
|
||||
"""Handles transcription requests."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: RequestLogger | None,
|
||||
return_tokens_as_token_ids: bool = False,
|
||||
log_error_stack: bool = False,
|
||||
enable_force_include_usage: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
engine_client=engine_client,
|
||||
models=models,
|
||||
request_logger=request_logger,
|
||||
return_tokens_as_token_ids=return_tokens_as_token_ids,
|
||||
task_type="transcribe",
|
||||
log_error_stack=log_error_stack,
|
||||
enable_force_include_usage=enable_force_include_usage,
|
||||
)
|
||||
|
||||
async def create_transcription(
|
||||
self, audio_data: bytes, request: TranscriptionRequest, raw_request: Request
|
||||
) -> TranscriptionResponse | AsyncGenerator[str, None] | ErrorResponse:
|
||||
"""Transcription API similar to OpenAI's API.
|
||||
|
||||
See https://platform.openai.com/docs/api-reference/audio/createTranscription
|
||||
for the API specification. This API mimics the OpenAI transcription API.
|
||||
"""
|
||||
return await self._create_speech_to_text(
|
||||
audio_data=audio_data,
|
||||
request=request,
|
||||
raw_request=raw_request,
|
||||
response_class=TranscriptionResponse,
|
||||
stream_generator_method=self.transcription_stream_generator,
|
||||
)
|
||||
|
||||
async def transcription_stream_generator(
|
||||
self,
|
||||
request: TranscriptionRequest,
|
||||
result_generator: list[AsyncGenerator[RequestOutput, None]],
|
||||
request_id: str,
|
||||
request_metadata: RequestResponseMetadata,
|
||||
audio_duration_s: float,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
generator = self._speech_to_text_stream_generator(
|
||||
request=request,
|
||||
list_result_generator=result_generator,
|
||||
request_id=request_id,
|
||||
request_metadata=request_metadata,
|
||||
audio_duration_s=audio_duration_s,
|
||||
chunk_object_type="transcription.chunk",
|
||||
response_stream_choice_class=TranscriptionResponseStreamChoice,
|
||||
stream_response_class=TranscriptionStreamResponse,
|
||||
)
|
||||
async for chunk in generator:
|
||||
yield chunk
|
||||
|
||||
|
||||
class OpenAIServingTranslation(OpenAISpeechToText):
|
||||
"""Handles translation requests."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: RequestLogger | None,
|
||||
return_tokens_as_token_ids: bool = False,
|
||||
log_error_stack: bool = False,
|
||||
enable_force_include_usage: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
engine_client=engine_client,
|
||||
models=models,
|
||||
request_logger=request_logger,
|
||||
return_tokens_as_token_ids=return_tokens_as_token_ids,
|
||||
task_type="translate",
|
||||
log_error_stack=log_error_stack,
|
||||
enable_force_include_usage=enable_force_include_usage,
|
||||
)
|
||||
|
||||
async def create_translation(
|
||||
self, audio_data: bytes, request: TranslationRequest, raw_request: Request
|
||||
) -> TranslationResponse | AsyncGenerator[str, None] | ErrorResponse:
|
||||
"""Translation API similar to OpenAI's API.
|
||||
|
||||
See https://platform.openai.com/docs/api-reference/audio/createTranslation
|
||||
for the API specification. This API mimics the OpenAI translation API.
|
||||
"""
|
||||
return await self._create_speech_to_text(
|
||||
audio_data=audio_data,
|
||||
request=request,
|
||||
raw_request=raw_request,
|
||||
response_class=TranslationResponse,
|
||||
stream_generator_method=self.translation_stream_generator,
|
||||
)
|
||||
|
||||
async def translation_stream_generator(
|
||||
self,
|
||||
request: TranslationRequest,
|
||||
result_generator: list[AsyncGenerator[RequestOutput, None]],
|
||||
request_id: str,
|
||||
request_metadata: RequestResponseMetadata,
|
||||
audio_duration_s: float,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
generator = self._speech_to_text_stream_generator(
|
||||
request=request,
|
||||
list_result_generator=result_generator,
|
||||
request_id=request_id,
|
||||
request_metadata=request_metadata,
|
||||
audio_duration_s=audio_duration_s,
|
||||
chunk_object_type="translation.chunk",
|
||||
response_stream_choice_class=TranslationResponseStreamChoice,
|
||||
stream_response_class=TranslationStreamResponse,
|
||||
)
|
||||
async for chunk in generator:
|
||||
yield chunk
|
||||
405
vllm_old/entrypoints/openai/speech_to_text.py
Normal file
405
vllm_old/entrypoints/openai/speech_to_text.py
Normal file
@@ -0,0 +1,405 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import asyncio
|
||||
import io
|
||||
import math
|
||||
import time
|
||||
from collections.abc import AsyncGenerator, Callable
|
||||
from functools import cached_property
|
||||
from typing import Literal, TypeAlias, TypeVar, cast
|
||||
|
||||
import numpy as np
|
||||
from fastapi import Request
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
DeltaMessage,
|
||||
ErrorResponse,
|
||||
RequestResponseMetadata,
|
||||
TranscriptionResponse,
|
||||
TranscriptionResponseStreamChoice,
|
||||
TranscriptionStreamResponse,
|
||||
TranslationResponse,
|
||||
TranslationResponseStreamChoice,
|
||||
TranslationStreamResponse,
|
||||
UsageInfo,
|
||||
)
|
||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing, SpeechToTextRequest
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.inputs.data import PromptType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.models import SupportsTranscription
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.utils.import_utils import PlaceholderModule
|
||||
|
||||
try:
|
||||
import librosa
|
||||
except ImportError:
|
||||
librosa = PlaceholderModule("librosa") # type: ignore[assignment]
|
||||
|
||||
SpeechToTextResponse: TypeAlias = TranscriptionResponse | TranslationResponse
|
||||
T = TypeVar("T", bound=SpeechToTextResponse)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class OpenAISpeechToText(OpenAIServing):
|
||||
"""Base class for speech-to-text operations like transcription and
|
||||
translation."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: RequestLogger | None,
|
||||
return_tokens_as_token_ids: bool = False,
|
||||
task_type: Literal["transcribe", "translate"] = "transcribe",
|
||||
log_error_stack: bool = False,
|
||||
enable_force_include_usage: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
engine_client=engine_client,
|
||||
models=models,
|
||||
request_logger=request_logger,
|
||||
return_tokens_as_token_ids=return_tokens_as_token_ids,
|
||||
log_error_stack=log_error_stack,
|
||||
)
|
||||
|
||||
self.default_sampling_params = self.model_config.get_diff_sampling_param()
|
||||
self.task_type = task_type
|
||||
|
||||
self.asr_config = self.model_cls.get_speech_to_text_config(
|
||||
self.model_config, task_type
|
||||
)
|
||||
|
||||
self.enable_force_include_usage = enable_force_include_usage
|
||||
|
||||
self.max_audio_filesize_mb = envs.VLLM_MAX_AUDIO_CLIP_FILESIZE_MB
|
||||
|
||||
if self.default_sampling_params:
|
||||
logger.info(
|
||||
"Overwriting default completion sampling param with: %s",
|
||||
self.default_sampling_params,
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def model_cls(self) -> type[SupportsTranscription]:
|
||||
from vllm.model_executor.model_loader import get_model_cls
|
||||
|
||||
model_cls = get_model_cls(self.model_config)
|
||||
return cast(type[SupportsTranscription], model_cls)
|
||||
|
||||
async def _preprocess_speech_to_text(
|
||||
self,
|
||||
request: SpeechToTextRequest,
|
||||
audio_data: bytes,
|
||||
) -> tuple[list[PromptType], float]:
|
||||
# Validate request
|
||||
language = self.model_cls.validate_language(request.language)
|
||||
# Skip to_language validation to avoid extra logging for Whisper.
|
||||
to_language = (
|
||||
self.model_cls.validate_language(request.to_language)
|
||||
if request.to_language
|
||||
else None
|
||||
)
|
||||
|
||||
if len(audio_data) / 1024**2 > self.max_audio_filesize_mb:
|
||||
raise ValueError("Maximum file size exceeded.")
|
||||
|
||||
with io.BytesIO(audio_data) as bytes_:
|
||||
# NOTE resample to model SR here for efficiency. This is also a
|
||||
# pre-requisite for chunking, as it assumes Whisper SR.
|
||||
y, sr = librosa.load(bytes_, sr=self.asr_config.sample_rate)
|
||||
|
||||
duration = librosa.get_duration(y=y, sr=sr)
|
||||
do_split_audio = (
|
||||
self.asr_config.allow_audio_chunking
|
||||
and duration > self.asr_config.max_audio_clip_s
|
||||
)
|
||||
chunks = [y] if not do_split_audio else self._split_audio(y, int(sr))
|
||||
prompts = []
|
||||
for chunk in chunks:
|
||||
# The model has control over the construction, as long as it
|
||||
# returns a valid PromptType.
|
||||
prompt = self.model_cls.get_generation_prompt(
|
||||
audio=chunk,
|
||||
stt_config=self.asr_config,
|
||||
model_config=self.model_config,
|
||||
language=language,
|
||||
task_type=self.task_type,
|
||||
request_prompt=request.prompt,
|
||||
to_language=to_language,
|
||||
)
|
||||
prompts.append(prompt)
|
||||
return prompts, duration
|
||||
|
||||
async def _create_speech_to_text(
|
||||
self,
|
||||
audio_data: bytes,
|
||||
request: SpeechToTextRequest,
|
||||
raw_request: Request,
|
||||
response_class: type[T],
|
||||
stream_generator_method: Callable[..., AsyncGenerator[str, None]],
|
||||
) -> T | AsyncGenerator[str, None] | ErrorResponse:
|
||||
"""Base method for speech-to-text operations like transcription and
|
||||
translation."""
|
||||
error_check_ret = await self._check_model(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
# If the engine is dead, raise the engine's DEAD_ERROR.
|
||||
# This is required for the streaming case, where we return a
|
||||
# success status before we actually start generating text :).
|
||||
if self.engine_client.errored:
|
||||
raise self.engine_client.dead_error
|
||||
|
||||
if request.response_format not in ["text", "json"]:
|
||||
return self.create_error_response(
|
||||
"Currently only support response_format `text` or `json`"
|
||||
)
|
||||
|
||||
request_id = f"{self.task_type}-{self._base_request_id(raw_request)}"
|
||||
|
||||
request_metadata = RequestResponseMetadata(request_id=request_id)
|
||||
if raw_request:
|
||||
raw_request.state.request_metadata = request_metadata
|
||||
|
||||
try:
|
||||
lora_request = self._maybe_get_adapters(request)
|
||||
|
||||
prompts, duration_s = await self._preprocess_speech_to_text(
|
||||
request=request,
|
||||
audio_data=audio_data,
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
list_result_generator: list[AsyncGenerator[RequestOutput, None]] | None = None
|
||||
try:
|
||||
# Unlike most decoder-only models, whisper generation length is not
|
||||
# constrained by the size of the input audio, which is mapped to a
|
||||
# fixed-size log-mel-spectogram.
|
||||
default_max_tokens = self.model_config.max_model_len
|
||||
sampling_params = request.to_sampling_params(
|
||||
default_max_tokens, self.default_sampling_params
|
||||
)
|
||||
|
||||
self._log_inputs(
|
||||
request_id,
|
||||
# It will not display special tokens like <|startoftranscript|>
|
||||
request.prompt,
|
||||
params=sampling_params,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
|
||||
list_result_generator = [
|
||||
self.engine_client.generate(
|
||||
prompt,
|
||||
sampling_params,
|
||||
request_id,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
for prompt in prompts
|
||||
]
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
if request.stream:
|
||||
return stream_generator_method(
|
||||
request, list_result_generator, request_id, request_metadata, duration_s
|
||||
)
|
||||
# Non-streaming response.
|
||||
try:
|
||||
assert list_result_generator is not None
|
||||
text = ""
|
||||
for result_generator in list_result_generator:
|
||||
async for op in result_generator:
|
||||
text += op.outputs[0].text
|
||||
|
||||
if self.task_type == "transcribe":
|
||||
# add usage in TranscriptionResponse.
|
||||
usage = {
|
||||
"type": "duration",
|
||||
# rounded up as per openAI specs
|
||||
"seconds": int(math.ceil(duration_s)),
|
||||
}
|
||||
final_response = cast(T, response_class(text=text, usage=usage))
|
||||
else:
|
||||
# no usage in response for translation task
|
||||
final_response = cast(T, response_class(text=text)) # type: ignore[call-arg]
|
||||
|
||||
return final_response
|
||||
except asyncio.CancelledError:
|
||||
return self.create_error_response("Client disconnected")
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
async def _speech_to_text_stream_generator(
|
||||
self,
|
||||
request: SpeechToTextRequest,
|
||||
list_result_generator: list[AsyncGenerator[RequestOutput, None]],
|
||||
request_id: str,
|
||||
request_metadata: RequestResponseMetadata,
|
||||
audio_duration_s: float,
|
||||
chunk_object_type: Literal["translation.chunk", "transcription.chunk"],
|
||||
response_stream_choice_class: type[TranscriptionResponseStreamChoice]
|
||||
| type[TranslationResponseStreamChoice],
|
||||
stream_response_class: type[TranscriptionStreamResponse]
|
||||
| type[TranslationStreamResponse],
|
||||
) -> AsyncGenerator[str, None]:
|
||||
created_time = int(time.time())
|
||||
model_name = request.model
|
||||
|
||||
completion_tokens = 0
|
||||
num_prompt_tokens = 0
|
||||
|
||||
include_usage = self.enable_force_include_usage or request.stream_include_usage
|
||||
include_continuous_usage = (
|
||||
request.stream_continuous_usage_stats
|
||||
if include_usage and request.stream_continuous_usage_stats
|
||||
else False
|
||||
)
|
||||
|
||||
try:
|
||||
for result_generator in list_result_generator:
|
||||
async for res in result_generator:
|
||||
# On first result.
|
||||
if res.prompt_token_ids is not None:
|
||||
num_prompt_tokens = len(res.prompt_token_ids)
|
||||
if audio_tokens := self.model_cls.get_num_audio_tokens(
|
||||
audio_duration_s, self.asr_config, self.model_config
|
||||
):
|
||||
num_prompt_tokens += audio_tokens
|
||||
|
||||
# We need to do it here, because if there are exceptions in
|
||||
# the result_generator, it needs to be sent as the FIRST
|
||||
# response (by the try...catch).
|
||||
|
||||
# Just one output (n=1) supported.
|
||||
assert len(res.outputs) == 1
|
||||
output = res.outputs[0]
|
||||
|
||||
delta_message = DeltaMessage(content=output.text)
|
||||
completion_tokens += len(output.token_ids)
|
||||
|
||||
if output.finish_reason is None:
|
||||
# Still generating, send delta update.
|
||||
choice_data = response_stream_choice_class(delta=delta_message)
|
||||
else:
|
||||
# Model is finished generating.
|
||||
choice_data = response_stream_choice_class(
|
||||
delta=delta_message,
|
||||
finish_reason=output.finish_reason,
|
||||
stop_reason=output.stop_reason,
|
||||
)
|
||||
|
||||
chunk = stream_response_class(
|
||||
id=request_id,
|
||||
object=chunk_object_type,
|
||||
created=created_time,
|
||||
choices=[choice_data],
|
||||
model=model_name,
|
||||
)
|
||||
|
||||
# handle usage stats if requested & if continuous
|
||||
if include_continuous_usage:
|
||||
chunk.usage = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=num_prompt_tokens + completion_tokens,
|
||||
)
|
||||
|
||||
data = chunk.model_dump_json(exclude_unset=True)
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
# Once the final token is handled, if stream_options.include_usage
|
||||
# is sent, send the usage.
|
||||
if include_usage:
|
||||
final_usage = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=num_prompt_tokens + completion_tokens,
|
||||
)
|
||||
|
||||
final_usage_chunk = stream_response_class(
|
||||
id=request_id,
|
||||
object=chunk_object_type,
|
||||
created=created_time,
|
||||
choices=[],
|
||||
model=model_name,
|
||||
usage=final_usage,
|
||||
)
|
||||
final_usage_data = final_usage_chunk.model_dump_json(
|
||||
exclude_unset=True, exclude_none=True
|
||||
)
|
||||
yield f"data: {final_usage_data}\n\n"
|
||||
|
||||
# report to FastAPI middleware aggregate usage across all choices
|
||||
request_metadata.final_usage_info = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=num_prompt_tokens + completion_tokens,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
logger.exception("Error in %s stream generator.", self.task_type)
|
||||
data = self.create_streaming_error_response(str(e))
|
||||
yield f"data: {data}\n\n"
|
||||
# Send the final done message after all response.n are finished
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
def _split_audio(
|
||||
self, audio_data: np.ndarray, sample_rate: int
|
||||
) -> list[np.ndarray]:
|
||||
chunk_size = sample_rate * self.asr_config.max_audio_clip_s
|
||||
overlap_size = sample_rate * self.asr_config.overlap_chunk_second
|
||||
chunks = []
|
||||
i = 0
|
||||
while i < audio_data.shape[-1]:
|
||||
if i + chunk_size >= audio_data.shape[-1]:
|
||||
# handle last chunk
|
||||
chunks.append(audio_data[..., i:])
|
||||
break
|
||||
|
||||
# Find the best split point in the overlap region
|
||||
search_start = i + chunk_size - overlap_size
|
||||
search_end = min(i + chunk_size, audio_data.shape[-1])
|
||||
split_point = self._find_split_point(audio_data, search_start, search_end)
|
||||
|
||||
# Extract chunk up to the split point
|
||||
chunks.append(audio_data[..., i:split_point])
|
||||
i = split_point
|
||||
return chunks
|
||||
|
||||
def _find_split_point(self, wav: np.ndarray, start_idx: int, end_idx: int) -> int:
|
||||
"""Find the best point to split audio by
|
||||
looking for silence or low amplitude.
|
||||
Args:
|
||||
wav: Audio tensor [1, T]
|
||||
start_idx: Start index of search region
|
||||
end_idx: End index of search region
|
||||
Returns:
|
||||
Index of best splitting point
|
||||
"""
|
||||
segment = wav[start_idx:end_idx]
|
||||
|
||||
# Calculate RMS energy in small windows
|
||||
min_energy = math.inf
|
||||
quietest_idx = 0
|
||||
min_energy_window = self.asr_config.min_energy_split_window_size
|
||||
assert min_energy_window is not None
|
||||
for i in range(0, len(segment) - min_energy_window, min_energy_window):
|
||||
window = segment[i : i + min_energy_window]
|
||||
energy = (window**2).mean() ** 0.5
|
||||
if energy < min_energy:
|
||||
quietest_idx = i + start_idx
|
||||
min_energy = energy
|
||||
return quietest_idx
|
||||
142
vllm_old/entrypoints/openai/tool_parsers/__init__.py
Normal file
142
vllm_old/entrypoints/openai/tool_parsers/__init__.py
Normal file
@@ -0,0 +1,142 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser,
|
||||
ToolParserManager,
|
||||
)
|
||||
|
||||
__all__ = ["ToolParser", "ToolParserManager"]
|
||||
|
||||
|
||||
"""
|
||||
Register a lazy module mapping.
|
||||
|
||||
Example:
|
||||
ToolParserManager.register_lazy_module(
|
||||
name="kimi_k2",
|
||||
module_path="vllm.entrypoints.openai.tool_parsers.kimi_k2_parser",
|
||||
class_name="KimiK2ToolParser",
|
||||
)
|
||||
"""
|
||||
|
||||
|
||||
_TOOL_PARSERS_TO_REGISTER = {
|
||||
"deepseek_v3": ( # name
|
||||
"deepseekv3_tool_parser", # filename
|
||||
"DeepSeekV3ToolParser", # class_name
|
||||
),
|
||||
"deepseek_v31": (
|
||||
"deepseekv31_tool_parser",
|
||||
"DeepSeekV31ToolParser",
|
||||
),
|
||||
"ernie45": (
|
||||
"ernie45_tool_parser",
|
||||
"Ernie45ToolParser",
|
||||
),
|
||||
"glm45": (
|
||||
"glm4_moe_tool_parser",
|
||||
"Glm4MoeModelToolParser",
|
||||
),
|
||||
"granite-20b-fc": (
|
||||
"granite_20b_fc_tool_parser",
|
||||
"Granite20bFCToolParser",
|
||||
),
|
||||
"granite": (
|
||||
"granite_tool_parser",
|
||||
"GraniteToolParser",
|
||||
),
|
||||
"hermes": (
|
||||
"hermes_tool_parser",
|
||||
"Hermes2ProToolParser",
|
||||
),
|
||||
"hunyuan_a13b": (
|
||||
"hunyuan_a13b_tool_parser",
|
||||
"HunyuanA13BToolParser",
|
||||
),
|
||||
"internlm": (
|
||||
"internlm2_tool_parser",
|
||||
"Internlm2ToolParser",
|
||||
),
|
||||
"jamba": (
|
||||
"jamba_tool_parser",
|
||||
"JambaToolParser",
|
||||
),
|
||||
"kimi_k2": (
|
||||
"kimi_k2_tool_parser",
|
||||
"KimiK2ToolParser",
|
||||
),
|
||||
"llama3_json": (
|
||||
"llama_tool_parser",
|
||||
"Llama3JsonToolParser",
|
||||
),
|
||||
"llama4_json": (
|
||||
"llama_tool_parser",
|
||||
"Llama3JsonToolParser",
|
||||
),
|
||||
"llama4_pythonic": (
|
||||
"llama4_pythonic_tool_parser",
|
||||
"Llama4PythonicToolParser",
|
||||
),
|
||||
"longcat": (
|
||||
"longcat_tool_parser",
|
||||
"LongcatFlashToolParser",
|
||||
),
|
||||
"minimax_m2": (
|
||||
"minimax_m2_tool_parser",
|
||||
"MinimaxM2ToolParser",
|
||||
),
|
||||
"minimax": (
|
||||
"minimax_tool_parser",
|
||||
"MinimaxToolParser",
|
||||
),
|
||||
"mistral": (
|
||||
"mistral_tool_parser",
|
||||
"MistralToolParser",
|
||||
),
|
||||
"olmo3": (
|
||||
"olmo3_tool_parser",
|
||||
"Olmo3PythonicToolParser",
|
||||
),
|
||||
"openai": (
|
||||
"openai_tool_parser",
|
||||
"OpenAIToolParser",
|
||||
),
|
||||
"phi4_mini_json": (
|
||||
"phi4mini_tool_parser",
|
||||
"Phi4MiniJsonToolParser",
|
||||
),
|
||||
"pythonic": (
|
||||
"pythonic_tool_parser",
|
||||
"PythonicToolParser",
|
||||
),
|
||||
"qwen3_coder": (
|
||||
"qwen3coder_tool_parser",
|
||||
"Qwen3CoderToolParser",
|
||||
),
|
||||
"qwen3_xml": (
|
||||
"qwen3xml_tool_parser",
|
||||
"Qwen3XMLToolParser",
|
||||
),
|
||||
"seed_oss": (
|
||||
"seed_oss_tool_parser",
|
||||
"SeedOssToolParser",
|
||||
),
|
||||
"step3": (
|
||||
"step3_tool_parser",
|
||||
"Step3ToolParser",
|
||||
),
|
||||
"xlam": (
|
||||
"xlam_tool_parser",
|
||||
"xLAMToolParser",
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def register_lazy_tool_parsers():
|
||||
for name, (file_name, class_name) in _TOOL_PARSERS_TO_REGISTER.items():
|
||||
module_path = f"vllm.entrypoints.openai.tool_parsers.{file_name}"
|
||||
ToolParserManager.register_lazy_module(name, module_path, class_name)
|
||||
|
||||
|
||||
register_lazy_tool_parsers()
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user