This commit is contained in:
root
2026-04-09 11:19:36 +08:00
parent 809cecae09
commit 8082d5f4b2
2579 changed files with 3675 additions and 0 deletions

View File

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -0,0 +1,162 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Pydantic models for Anthropic API protocol"""
import time
from typing import Any, Literal, Optional
from pydantic import BaseModel, field_validator
class AnthropicError(BaseModel):
"""Error structure for Anthropic API"""
type: str
message: str
class AnthropicErrorResponse(BaseModel):
"""Error response structure for Anthropic API"""
type: Literal["error"] = "error"
error: AnthropicError
class AnthropicUsage(BaseModel):
"""Token usage information"""
input_tokens: int
output_tokens: int
cache_creation_input_tokens: int | None = None
cache_read_input_tokens: int | None = None
class AnthropicContentBlock(BaseModel):
"""Content block in message"""
type: Literal["text", "image", "tool_use", "tool_result"]
text: str | None = None
# For image content
source: dict[str, Any] | None = None
# For tool use/result
id: str | None = None
name: str | None = None
input: dict[str, Any] | None = None
content: str | list[dict[str, Any]] | None = None
is_error: bool | None = None
class AnthropicMessage(BaseModel):
"""Message structure"""
role: Literal["user", "assistant"]
content: str | list[AnthropicContentBlock]
class AnthropicTool(BaseModel):
"""Tool definition"""
name: str
description: str | None = None
input_schema: dict[str, Any]
@field_validator("input_schema")
@classmethod
def validate_input_schema(cls, v):
if not isinstance(v, dict):
raise ValueError("input_schema must be a dictionary")
if "type" not in v:
v["type"] = "object" # Default to object type
return v
class AnthropicToolChoice(BaseModel):
"""Tool Choice definition"""
type: Literal["auto", "any", "tool"]
name: str | None = None
class AnthropicMessagesRequest(BaseModel):
"""Anthropic Messages API request"""
model: str
messages: list[AnthropicMessage]
max_tokens: int
metadata: dict[str, Any] | None = None
stop_sequences: list[str] | None = None
stream: bool | None = False
system: str | list[AnthropicContentBlock] | None = None
temperature: float | None = None
tool_choice: AnthropicToolChoice | None = None
tools: list[AnthropicTool] | None = None
top_k: int | None = None
top_p: float | None = None
@field_validator("model")
@classmethod
def validate_model(cls, v):
if not v:
raise ValueError("Model is required")
return v
@field_validator("max_tokens")
@classmethod
def validate_max_tokens(cls, v):
if v <= 0:
raise ValueError("max_tokens must be positive")
return v
class AnthropicDelta(BaseModel):
"""Delta for streaming responses"""
type: Literal["text_delta", "input_json_delta"] | None = None
text: str | None = None
partial_json: str | None = None
# Message delta
stop_reason: (
Literal["end_turn", "max_tokens", "stop_sequence", "tool_use"] | None
) = None
stop_sequence: str | None = None
class AnthropicStreamEvent(BaseModel):
"""Streaming event"""
type: Literal[
"message_start",
"message_delta",
"message_stop",
"content_block_start",
"content_block_delta",
"content_block_stop",
"ping",
"error",
]
message: Optional["AnthropicMessagesResponse"] = None
delta: AnthropicDelta | None = None
content_block: AnthropicContentBlock | None = None
index: int | None = None
error: AnthropicError | None = None
usage: AnthropicUsage | None = None
class AnthropicMessagesResponse(BaseModel):
"""Anthropic Messages API response"""
id: str
type: Literal["message"] = "message"
role: Literal["assistant"] = "assistant"
content: list[AnthropicContentBlock]
model: str
stop_reason: (
Literal["end_turn", "max_tokens", "stop_sequence", "tool_use"] | None
) = None
stop_sequence: str | None = None
usage: AnthropicUsage | None = None
def model_post_init(self, __context):
if not self.id:
self.id = f"msg_{int(time.time() * 1000)}"

View File

@@ -0,0 +1,460 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from
# https://github.com/vllm/vllm/entrypoints/openai/serving_chat.py
"""Anthropic Messages API serving handler"""
import json
import logging
import time
from collections.abc import AsyncGenerator
from typing import Any
from fastapi import Request
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.anthropic.protocol import (
AnthropicContentBlock,
AnthropicDelta,
AnthropicError,
AnthropicMessagesRequest,
AnthropicMessagesResponse,
AnthropicStreamEvent,
AnthropicUsage,
)
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (
ChatCompletionNamedToolChoiceParam,
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionStreamResponse,
ChatCompletionToolsParam,
ErrorResponse,
StreamOptions,
)
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
logger = logging.getLogger(__name__)
def wrap_data_with_event(data: str, event: str):
return f"event: {event}\ndata: {data}\n\n"
class AnthropicServingMessages(OpenAIServingChat):
"""Handler for Anthropic Messages API requests"""
def __init__(
self,
engine_client: EngineClient,
models: OpenAIServingModels,
response_role: str,
*,
request_logger: RequestLogger | None,
chat_template: str | None,
chat_template_content_format: ChatTemplateContentFormatOption,
return_tokens_as_token_ids: bool = False,
reasoning_parser: str = "",
enable_auto_tools: bool = False,
tool_parser: str | None = None,
enable_prompt_tokens_details: bool = False,
enable_force_include_usage: bool = False,
):
super().__init__(
engine_client=engine_client,
models=models,
response_role=response_role,
request_logger=request_logger,
chat_template=chat_template,
chat_template_content_format=chat_template_content_format,
return_tokens_as_token_ids=return_tokens_as_token_ids,
reasoning_parser=reasoning_parser,
enable_auto_tools=enable_auto_tools,
tool_parser=tool_parser,
enable_prompt_tokens_details=enable_prompt_tokens_details,
enable_force_include_usage=enable_force_include_usage,
)
self.stop_reason_map = {
"stop": "end_turn",
"length": "max_tokens",
"tool_calls": "tool_use",
}
def _convert_anthropic_to_openai_request(
self, anthropic_request: AnthropicMessagesRequest
) -> ChatCompletionRequest:
"""Convert Anthropic message format to OpenAI format"""
openai_messages = []
# Add system message if provided
if anthropic_request.system:
if isinstance(anthropic_request.system, str):
openai_messages.append(
{"role": "system", "content": anthropic_request.system}
)
else:
system_prompt = ""
for block in anthropic_request.system:
if block.type == "text" and block.text:
system_prompt += block.text
openai_messages.append({"role": "system", "content": system_prompt})
for msg in anthropic_request.messages:
openai_msg: dict[str, Any] = {"role": msg.role} # type: ignore
if isinstance(msg.content, str):
openai_msg["content"] = msg.content
else:
# Handle complex content blocks
content_parts: list[dict[str, Any]] = []
tool_calls: list[dict[str, Any]] = []
for block in msg.content:
if block.type == "text" and block.text:
content_parts.append({"type": "text", "text": block.text})
elif block.type == "image" and block.source:
content_parts.append(
{
"type": "image_url",
"image_url": {"url": block.source.get("data", "")},
}
)
elif block.type == "tool_use":
# Convert tool use to function call format
tool_call = {
"id": block.id or f"call_{int(time.time())}",
"type": "function",
"function": {
"name": block.name or "",
"arguments": json.dumps(block.input or {}),
},
}
tool_calls.append(tool_call)
elif block.type == "tool_result":
if msg.role == "user":
openai_messages.append(
{
"role": "tool",
"tool_call_id": block.id or "",
"content": str(block.content)
if block.content
else "",
}
)
else:
# Assistant tool result becomes regular text
tool_result_text = (
str(block.content) if block.content else ""
)
content_parts.append(
{
"type": "text",
"text": f"Tool result: {tool_result_text}",
}
)
# Add tool calls to the message if any
if tool_calls:
openai_msg["tool_calls"] = tool_calls # type: ignore
# Add content parts if any
if content_parts:
if len(content_parts) == 1 and content_parts[0]["type"] == "text":
openai_msg["content"] = content_parts[0]["text"]
else:
openai_msg["content"] = content_parts # type: ignore
elif not tool_calls:
continue
openai_messages.append(openai_msg)
req = ChatCompletionRequest(
model=anthropic_request.model,
messages=openai_messages,
max_tokens=anthropic_request.max_tokens,
max_completion_tokens=anthropic_request.max_tokens,
stop=anthropic_request.stop_sequences,
temperature=anthropic_request.temperature,
top_p=anthropic_request.top_p,
top_k=anthropic_request.top_k,
)
if anthropic_request.stream:
req.stream = anthropic_request.stream
req.stream_options = StreamOptions.validate({"include_usage": True})
if anthropic_request.tool_choice is None:
req.tool_choice = None
elif anthropic_request.tool_choice.type == "auto":
req.tool_choice = "auto"
elif anthropic_request.tool_choice.type == "any":
req.tool_choice = "required"
elif anthropic_request.tool_choice.type == "tool":
req.tool_choice = ChatCompletionNamedToolChoiceParam.model_validate(
{
"type": "function",
"function": {"name": anthropic_request.tool_choice.name},
}
)
tools = []
if anthropic_request.tools is None:
return req
for tool in anthropic_request.tools:
tools.append(
ChatCompletionToolsParam.model_validate(
{
"type": "function",
"function": {
"name": tool.name,
"description": tool.description,
"parameters": tool.input_schema,
},
}
)
)
if req.tool_choice is None:
req.tool_choice = "auto"
req.tools = tools
return req
async def create_messages(
self,
request: AnthropicMessagesRequest,
raw_request: Request | None = None,
) -> AsyncGenerator[str, None] | AnthropicMessagesResponse | ErrorResponse:
"""
Messages API similar to Anthropic's API.
See https://docs.anthropic.com/en/api/messages
for the API specification. This API mimics the Anthropic messages API.
"""
if logger.isEnabledFor(logging.DEBUG):
logger.debug("Received messages request %s", request.model_dump_json())
chat_req = self._convert_anthropic_to_openai_request(request)
if logger.isEnabledFor(logging.DEBUG):
logger.debug("Convert to OpenAI request %s", chat_req.model_dump_json())
generator = await self.create_chat_completion(chat_req, raw_request)
if isinstance(generator, ErrorResponse):
return generator
elif isinstance(generator, ChatCompletionResponse):
return self.messages_full_converter(generator)
return self.message_stream_converter(generator)
def messages_full_converter(
self,
generator: ChatCompletionResponse,
) -> AnthropicMessagesResponse:
result = AnthropicMessagesResponse(
id=generator.id,
content=[],
model=generator.model,
usage=AnthropicUsage(
input_tokens=generator.usage.prompt_tokens,
output_tokens=generator.usage.completion_tokens,
),
)
if generator.choices[0].finish_reason == "stop":
result.stop_reason = "end_turn"
elif generator.choices[0].finish_reason == "length":
result.stop_reason = "max_tokens"
elif generator.choices[0].finish_reason == "tool_calls":
result.stop_reason = "tool_use"
content: list[AnthropicContentBlock] = [
AnthropicContentBlock(
type="text",
text=generator.choices[0].message.content
if generator.choices[0].message.content
else "",
)
]
for tool_call in generator.choices[0].message.tool_calls:
anthropic_tool_call = AnthropicContentBlock(
type="tool_use",
id=tool_call.id,
name=tool_call.function.name,
input=json.loads(tool_call.function.arguments),
)
content += [anthropic_tool_call]
result.content = content
return result
async def message_stream_converter(
self,
generator: AsyncGenerator[str, None],
) -> AsyncGenerator[str, None]:
try:
first_item = True
finish_reason = None
content_block_index = 0
content_block_started = False
async for item in generator:
if item.startswith("data:"):
data_str = item[5:].strip().rstrip("\n")
if data_str == "[DONE]":
stop_message = AnthropicStreamEvent(
type="message_stop",
)
data = stop_message.model_dump_json(
exclude_unset=True, exclude_none=True
)
yield wrap_data_with_event(data, "message_stop")
yield "data: [DONE]\n\n"
else:
origin_chunk = ChatCompletionStreamResponse.model_validate_json(
data_str
)
if first_item:
chunk = AnthropicStreamEvent(
type="message_start",
message=AnthropicMessagesResponse(
id=origin_chunk.id,
content=[],
model=origin_chunk.model,
),
)
first_item = False
data = chunk.model_dump_json(exclude_unset=True)
yield wrap_data_with_event(data, "message_start")
continue
# last chunk including usage info
if len(origin_chunk.choices) == 0:
if content_block_started:
stop_chunk = AnthropicStreamEvent(
index=content_block_index,
type="content_block_stop",
)
data = stop_chunk.model_dump_json(exclude_unset=True)
yield wrap_data_with_event(data, "content_block_stop")
stop_reason = self.stop_reason_map.get(
finish_reason or "stop"
)
chunk = AnthropicStreamEvent(
type="message_delta",
delta=AnthropicDelta(stop_reason=stop_reason),
usage=AnthropicUsage(
input_tokens=origin_chunk.usage.prompt_tokens
if origin_chunk.usage
else 0,
output_tokens=origin_chunk.usage.completion_tokens
if origin_chunk.usage
else 0,
),
)
data = chunk.model_dump_json(exclude_unset=True)
yield wrap_data_with_event(data, "message_delta")
continue
if origin_chunk.choices[0].finish_reason is not None:
finish_reason = origin_chunk.choices[0].finish_reason
continue
# content
if origin_chunk.choices[0].delta.content is not None:
if not content_block_started:
chunk = AnthropicStreamEvent(
index=content_block_index,
type="content_block_start",
content_block=AnthropicContentBlock(
type="text", text=""
),
)
data = chunk.model_dump_json(exclude_unset=True)
yield wrap_data_with_event(data, "content_block_start")
content_block_started = True
if origin_chunk.choices[0].delta.content == "":
continue
chunk = AnthropicStreamEvent(
index=content_block_index,
type="content_block_delta",
delta=AnthropicDelta(
type="text_delta",
text=origin_chunk.choices[0].delta.content,
),
)
data = chunk.model_dump_json(exclude_unset=True)
yield wrap_data_with_event(data, "content_block_delta")
continue
# tool calls
elif len(origin_chunk.choices[0].delta.tool_calls) > 0:
tool_call = origin_chunk.choices[0].delta.tool_calls[0]
if tool_call.id is not None:
if content_block_started:
stop_chunk = AnthropicStreamEvent(
index=content_block_index,
type="content_block_stop",
)
data = stop_chunk.model_dump_json(
exclude_unset=True
)
yield wrap_data_with_event(
data, "content_block_stop"
)
content_block_started = False
content_block_index += 1
chunk = AnthropicStreamEvent(
index=content_block_index,
type="content_block_start",
content_block=AnthropicContentBlock(
type="tool_use",
id=tool_call.id,
name=tool_call.function.name
if tool_call.function
else None,
input={},
),
)
data = chunk.model_dump_json(exclude_unset=True)
yield wrap_data_with_event(data, "content_block_start")
content_block_started = True
else:
chunk = AnthropicStreamEvent(
index=content_block_index,
type="content_block_delta",
delta=AnthropicDelta(
type="input_json_delta",
partial_json=tool_call.function.arguments
if tool_call.function
else None,
),
)
data = chunk.model_dump_json(exclude_unset=True)
yield wrap_data_with_event(data, "content_block_delta")
continue
else:
error_response = AnthropicStreamEvent(
type="error",
error=AnthropicError(
type="internal_error",
message="Invalid data format received",
),
)
data = error_response.model_dump_json(exclude_unset=True)
yield wrap_data_with_event(data, "error")
yield "data: [DONE]\n\n"
except Exception as e:
logger.exception("Error in message stream converter.")
error_response = AnthropicStreamEvent(
type="error",
error=AnthropicError(type="internal_error", message=str(e)),
)
data = error_response.model_dump_json(exclude_unset=True)
yield wrap_data_with_event(data, "error")
yield "data: [DONE]\n\n"

View File

@@ -0,0 +1,184 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
NOTE: This API server is used only for demonstrating usage of AsyncEngine
and simple performance benchmarks. It is not intended for production use.
For production use, we recommend using our OpenAI compatible server.
We are also not going to accept PRs modifying this file, please
change `vllm/entrypoints/openai/api_server.py` instead.
"""
import asyncio
import json
import ssl
from argparse import Namespace
from collections.abc import AsyncGenerator
from typing import Any
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse
import vllm.envs as envs
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.launcher import serve_http
from vllm.entrypoints.utils import with_cancellation
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
from vllm.usage.usage_lib import UsageContext
from vllm.utils import random_uuid
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.system_utils import set_ulimit
from vllm.version import __version__ as VLLM_VERSION
logger = init_logger("vllm.entrypoints.api_server")
app = FastAPI()
engine = None
@app.get("/health")
async def health() -> Response:
"""Health check."""
return Response(status_code=200)
@app.post("/generate")
async def generate(request: Request) -> Response:
"""Generate completion for the request.
The request should be a JSON object with the following fields:
- prompt: the prompt to use for the generation.
- stream: whether to stream the results or not.
- other fields: the sampling parameters (See `SamplingParams` for details).
"""
request_dict = await request.json()
return await _generate(request_dict, raw_request=request)
@with_cancellation
async def _generate(request_dict: dict, raw_request: Request) -> Response:
prompt = request_dict.pop("prompt")
stream = request_dict.pop("stream", False)
sampling_params = SamplingParams(**request_dict)
request_id = random_uuid()
assert engine is not None
results_generator = engine.generate(prompt, sampling_params, request_id)
# Streaming case
async def stream_results() -> AsyncGenerator[bytes, None]:
async for request_output in results_generator:
prompt = request_output.prompt
assert prompt is not None
text_outputs = [prompt + output.text for output in request_output.outputs]
ret = {"text": text_outputs}
yield (json.dumps(ret) + "\n").encode("utf-8")
if stream:
return StreamingResponse(stream_results())
# Non-streaming case
final_output = None
try:
async for request_output in results_generator:
final_output = request_output
except asyncio.CancelledError:
return Response(status_code=499)
assert final_output is not None
prompt = final_output.prompt
assert prompt is not None
text_outputs = [prompt + output.text for output in final_output.outputs]
ret = {"text": text_outputs}
return JSONResponse(ret)
def build_app(args: Namespace) -> FastAPI:
global app
app.root_path = args.root_path
return app
async def init_app(
args: Namespace,
llm_engine: AsyncLLMEngine | None = None,
) -> FastAPI:
app = build_app(args)
global engine
engine_args = AsyncEngineArgs.from_cli_args(args)
engine = (
llm_engine
if llm_engine is not None
else AsyncLLMEngine.from_engine_args(
engine_args, usage_context=UsageContext.API_SERVER
)
)
app.state.engine_client = engine
return app
async def run_server(
args: Namespace, llm_engine: AsyncLLMEngine | None = None, **uvicorn_kwargs: Any
) -> None:
logger.info("vLLM API server version %s", VLLM_VERSION)
logger.info("args: %s", args)
set_ulimit()
app = await init_app(args, llm_engine)
assert engine is not None
shutdown_task = await serve_http(
app,
sock=None,
enable_ssl_refresh=args.enable_ssl_refresh,
host=args.host,
port=args.port,
log_level=args.log_level,
timeout_keep_alive=envs.VLLM_HTTP_TIMEOUT_KEEP_ALIVE,
ssl_keyfile=args.ssl_keyfile,
ssl_certfile=args.ssl_certfile,
ssl_ca_certs=args.ssl_ca_certs,
ssl_cert_reqs=args.ssl_cert_reqs,
**uvicorn_kwargs,
)
await shutdown_task
if __name__ == "__main__":
parser = FlexibleArgumentParser()
parser.add_argument("--host", type=str, default=None)
parser.add_argument("--port", type=parser.check_port, default=8000)
parser.add_argument("--ssl-keyfile", type=str, default=None)
parser.add_argument("--ssl-certfile", type=str, default=None)
parser.add_argument(
"--ssl-ca-certs", type=str, default=None, help="The CA certificates file"
)
parser.add_argument(
"--enable-ssl-refresh",
action="store_true",
default=False,
help="Refresh SSL Context when SSL certificate files change",
)
parser.add_argument(
"--ssl-cert-reqs",
type=int,
default=int(ssl.CERT_NONE),
help="Whether client certificate is required (see stdlib ssl module's)",
)
parser.add_argument(
"--root-path",
type=str,
default=None,
help="FastAPI root_path when app is behind a path based routing proxy",
)
parser.add_argument("--log-level", type=str, default="debug")
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()
asyncio.run(run_server(args))

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.entrypoints.cli.benchmark.latency import BenchmarkLatencySubcommand
from vllm.entrypoints.cli.benchmark.serve import BenchmarkServingSubcommand
from vllm.entrypoints.cli.benchmark.sweep import BenchmarkSweepSubcommand
from vllm.entrypoints.cli.benchmark.throughput import BenchmarkThroughputSubcommand
__all__: list[str] = [
"BenchmarkLatencySubcommand",
"BenchmarkServingSubcommand",
"BenchmarkSweepSubcommand",
"BenchmarkThroughputSubcommand",
]

View File

@@ -0,0 +1,25 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
from vllm.entrypoints.cli.types import CLISubcommand
class BenchmarkSubcommandBase(CLISubcommand):
"""The base class of subcommands for `vllm bench`."""
help: str
@classmethod
def add_cli_args(cls, parser: argparse.ArgumentParser) -> None:
"""Add the CLI arguments to the parser."""
raise NotImplementedError
@staticmethod
def cmd(args: argparse.Namespace) -> None:
"""Run the benchmark.
Args:
args: The arguments to the command.
"""
raise NotImplementedError

View File

@@ -0,0 +1,21 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
from vllm.benchmarks.latency import add_cli_args, main
from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase
class BenchmarkLatencySubcommand(BenchmarkSubcommandBase):
"""The `latency` subcommand for `vllm bench`."""
name = "latency"
help = "Benchmark the latency of a single batch of requests."
@classmethod
def add_cli_args(cls, parser: argparse.ArgumentParser) -> None:
add_cli_args(parser)
@staticmethod
def cmd(args: argparse.Namespace) -> None:
main(args)

View File

@@ -0,0 +1,56 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import typing
from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase
from vllm.entrypoints.cli.types import CLISubcommand
from vllm.entrypoints.utils import VLLM_SUBCMD_PARSER_EPILOG
if typing.TYPE_CHECKING:
from vllm.utils.argparse_utils import FlexibleArgumentParser
else:
FlexibleArgumentParser = argparse.ArgumentParser
class BenchmarkSubcommand(CLISubcommand):
"""The `bench` subcommand for the vLLM CLI."""
name = "bench"
help = "vLLM bench subcommand."
@staticmethod
def cmd(args: argparse.Namespace) -> None:
args.dispatch_function(args)
def validate(self, args: argparse.Namespace) -> None:
pass
def subparser_init(
self, subparsers: argparse._SubParsersAction
) -> FlexibleArgumentParser:
bench_parser = subparsers.add_parser(
self.name,
description=self.help,
usage=f"vllm {self.name} <bench_type> [options]",
)
bench_subparsers = bench_parser.add_subparsers(required=True, dest="bench_type")
for cmd_cls in BenchmarkSubcommandBase.__subclasses__():
cmd_subparser = bench_subparsers.add_parser(
cmd_cls.name,
help=cmd_cls.help,
description=cmd_cls.help,
usage=f"vllm {self.name} {cmd_cls.name} [options]",
)
cmd_subparser.set_defaults(dispatch_function=cmd_cls.cmd)
cmd_cls.add_cli_args(cmd_subparser)
cmd_subparser.epilog = VLLM_SUBCMD_PARSER_EPILOG.format(
subcmd=f"{self.name} {cmd_cls.name}"
)
return bench_parser
def cmd_init() -> list[CLISubcommand]:
return [BenchmarkSubcommand()]

View File

@@ -0,0 +1,21 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
from vllm.benchmarks.serve import add_cli_args, main
from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase
class BenchmarkServingSubcommand(BenchmarkSubcommandBase):
"""The `serve` subcommand for `vllm bench`."""
name = "serve"
help = "Benchmark the online serving throughput."
@classmethod
def add_cli_args(cls, parser: argparse.ArgumentParser) -> None:
add_cli_args(parser)
@staticmethod
def cmd(args: argparse.Namespace) -> None:
main(args)

View File

@@ -0,0 +1,21 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
from vllm.benchmarks.sweep.cli import add_cli_args, main
from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase
class BenchmarkSweepSubcommand(BenchmarkSubcommandBase):
"""The `sweep` subcommand for `vllm bench`."""
name = "sweep"
help = "Benchmark for a parameter sweep."
@classmethod
def add_cli_args(cls, parser: argparse.ArgumentParser) -> None:
add_cli_args(parser)
@staticmethod
def cmd(args: argparse.Namespace) -> None:
main(args)

View File

@@ -0,0 +1,21 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
from vllm.benchmarks.throughput import add_cli_args, main
from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase
class BenchmarkThroughputSubcommand(BenchmarkSubcommandBase):
"""The `throughput` subcommand for `vllm bench`."""
name = "throughput"
help = "Benchmark offline inference throughput."
@classmethod
def add_cli_args(cls, parser: argparse.ArgumentParser) -> None:
add_cli_args(parser)
@staticmethod
def cmd(args: argparse.Namespace) -> None:
main(args)

View File

@@ -0,0 +1,38 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import typing
from vllm.collect_env import main as collect_env_main
from vllm.entrypoints.cli.types import CLISubcommand
if typing.TYPE_CHECKING:
from vllm.utils.argparse_utils import FlexibleArgumentParser
else:
FlexibleArgumentParser = argparse.ArgumentParser
class CollectEnvSubcommand(CLISubcommand):
"""The `collect-env` subcommand for the vLLM CLI."""
name = "collect-env"
@staticmethod
def cmd(args: argparse.Namespace) -> None:
"""Collect information about the environment."""
collect_env_main()
def subparser_init(
self, subparsers: argparse._SubParsersAction
) -> FlexibleArgumentParser:
return subparsers.add_parser(
"collect-env",
help="Start collecting environment information.",
description="Start collecting environment information.",
usage="vllm collect-env",
)
def cmd_init() -> list[CLISubcommand]:
return [CollectEnvSubcommand()]

View File

@@ -0,0 +1,79 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""The CLI entrypoints of vLLM
Note that all future modules must be lazily loaded within main
to avoid certain eager import breakage."""
import importlib.metadata
import sys
from vllm.logger import init_logger
logger = init_logger(__name__)
def main():
import vllm.entrypoints.cli.benchmark.main
import vllm.entrypoints.cli.collect_env
import vllm.entrypoints.cli.openai
import vllm.entrypoints.cli.run_batch
import vllm.entrypoints.cli.serve
from vllm.entrypoints.utils import VLLM_SUBCMD_PARSER_EPILOG, cli_env_setup
from vllm.utils.argparse_utils import FlexibleArgumentParser
CMD_MODULES = [
vllm.entrypoints.cli.openai,
vllm.entrypoints.cli.serve,
vllm.entrypoints.cli.benchmark.main,
vllm.entrypoints.cli.collect_env,
vllm.entrypoints.cli.run_batch,
]
cli_env_setup()
# For 'vllm bench *': use CPU instead of UnspecifiedPlatform by default
if len(sys.argv) > 1 and sys.argv[1] == "bench":
logger.debug(
"Bench command detected, must ensure current platform is not "
"UnspecifiedPlatform to avoid device type inference error"
)
from vllm import platforms
if platforms.current_platform.is_unspecified():
from vllm.platforms.cpu import CpuPlatform
platforms.current_platform = CpuPlatform()
logger.info(
"Unspecified platform detected, switching to CPU Platform instead."
)
parser = FlexibleArgumentParser(
description="vLLM CLI",
epilog=VLLM_SUBCMD_PARSER_EPILOG.format(subcmd="[subcommand]"),
)
parser.add_argument(
"-v",
"--version",
action="version",
version=importlib.metadata.version("vllm"),
)
subparsers = parser.add_subparsers(required=False, dest="subparser")
cmds = {}
for cmd_module in CMD_MODULES:
new_cmds = cmd_module.cmd_init()
for cmd in new_cmds:
cmd.subparser_init(subparsers).set_defaults(dispatch_function=cmd.cmd)
cmds[cmd.name] = cmd
args = parser.parse_args()
if args.subparser in cmds:
cmds[args.subparser].validate(args)
if hasattr(args, "dispatch_function"):
args.dispatch_function(args)
else:
parser.print_help()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,256 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import os
import signal
import sys
from typing import TYPE_CHECKING
from openai import OpenAI
from openai.types.chat import ChatCompletionMessageParam
from vllm.entrypoints.cli.types import CLISubcommand
if TYPE_CHECKING:
from vllm.utils.argparse_utils import FlexibleArgumentParser
else:
FlexibleArgumentParser = argparse.ArgumentParser
def _register_signal_handlers():
def signal_handler(sig, frame):
sys.exit(0)
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTSTP, signal_handler)
def _interactive_cli(args: argparse.Namespace) -> tuple[str, OpenAI]:
_register_signal_handlers()
base_url = args.url
api_key = args.api_key or os.environ.get("OPENAI_API_KEY", "EMPTY")
openai_client = OpenAI(api_key=api_key, base_url=base_url)
if args.model_name:
model_name = args.model_name
else:
available_models = openai_client.models.list()
model_name = available_models.data[0].id
print(f"Using model: {model_name}")
return model_name, openai_client
def _print_chat_stream(stream) -> str:
output = ""
for chunk in stream:
delta = chunk.choices[0].delta
if delta.content:
output += delta.content
print(delta.content, end="", flush=True)
print()
return output
def _print_completion_stream(stream) -> str:
output = ""
for chunk in stream:
text = chunk.choices[0].text
if text is not None:
output += text
print(text, end="", flush=True)
print()
return output
def chat(system_prompt: str | None, model_name: str, client: OpenAI) -> None:
conversation: list[ChatCompletionMessageParam] = []
if system_prompt is not None:
conversation.append({"role": "system", "content": system_prompt})
print("Please enter a message for the chat model:")
while True:
try:
input_message = input("> ")
except EOFError:
break
conversation.append({"role": "user", "content": input_message})
stream = client.chat.completions.create(
model=model_name, messages=conversation, stream=True
)
output = _print_chat_stream(stream)
conversation.append({"role": "assistant", "content": output})
def _add_query_options(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
parser.add_argument(
"--url",
type=str,
default="http://localhost:8000/v1",
help="url of the running OpenAI-Compatible RESTful API server",
)
parser.add_argument(
"--model-name",
type=str,
default=None,
help=(
"The model name used in prompt completion, default to "
"the first model in list models API call."
),
)
parser.add_argument(
"--api-key",
type=str,
default=None,
help=(
"API key for OpenAI services. If provided, this api key "
"will overwrite the api key obtained through environment variables."
),
)
return parser
class ChatCommand(CLISubcommand):
"""The `chat` subcommand for the vLLM CLI."""
name = "chat"
@staticmethod
def cmd(args: argparse.Namespace) -> None:
model_name, client = _interactive_cli(args)
system_prompt = args.system_prompt
conversation: list[ChatCompletionMessageParam] = []
if system_prompt is not None:
conversation.append({"role": "system", "content": system_prompt})
if args.quick:
conversation.append({"role": "user", "content": args.quick})
stream = client.chat.completions.create(
model=model_name, messages=conversation, stream=True
)
output = _print_chat_stream(stream)
conversation.append({"role": "assistant", "content": output})
return
print("Please enter a message for the chat model:")
while True:
try:
input_message = input("> ")
except EOFError:
break
conversation.append({"role": "user", "content": input_message})
stream = client.chat.completions.create(
model=model_name, messages=conversation, stream=True
)
output = _print_chat_stream(stream)
conversation.append({"role": "assistant", "content": output})
@staticmethod
def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"""Add CLI arguments for the chat command."""
_add_query_options(parser)
parser.add_argument(
"--system-prompt",
type=str,
default=None,
help=(
"The system prompt to be added to the chat template, "
"used for models that support system prompts."
),
)
parser.add_argument(
"-q",
"--quick",
type=str,
metavar="MESSAGE",
help=("Send a single prompt as MESSAGE and print the response, then exit."),
)
return parser
def subparser_init(
self, subparsers: argparse._SubParsersAction
) -> FlexibleArgumentParser:
parser = subparsers.add_parser(
"chat",
help="Generate chat completions via the running API server.",
description="Generate chat completions via the running API server.",
usage="vllm chat [options]",
)
return ChatCommand.add_cli_args(parser)
class CompleteCommand(CLISubcommand):
"""The `complete` subcommand for the vLLM CLI."""
name = "complete"
@staticmethod
def cmd(args: argparse.Namespace) -> None:
model_name, client = _interactive_cli(args)
kwargs = {
"model": model_name,
"stream": True,
}
if args.max_tokens:
kwargs["max_tokens"] = args.max_tokens
if args.quick:
stream = client.completions.create(prompt=args.quick, **kwargs)
_print_completion_stream(stream)
return
print("Please enter prompt to complete:")
while True:
try:
input_prompt = input("> ")
except EOFError:
break
stream = client.completions.create(prompt=input_prompt, **kwargs)
_print_completion_stream(stream)
@staticmethod
def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"""Add CLI arguments for the complete command."""
_add_query_options(parser)
parser.add_argument(
"--max-tokens",
type=int,
help="Maximum number of tokens to generate per output sequence.",
)
parser.add_argument(
"-q",
"--quick",
type=str,
metavar="PROMPT",
help="Send a single prompt and print the completion output, then exit.",
)
return parser
def subparser_init(
self, subparsers: argparse._SubParsersAction
) -> FlexibleArgumentParser:
parser = subparsers.add_parser(
"complete",
help=(
"Generate text completions based on the given prompt "
"via the running API server."
),
description=(
"Generate text completions based on the given prompt "
"via the running API server."
),
usage="vllm complete [options]",
)
return CompleteCommand.add_cli_args(parser)
def cmd_init() -> list[CLISubcommand]:
return [ChatCommand(), CompleteCommand()]

View File

@@ -0,0 +1,68 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import asyncio
import importlib.metadata
import typing
from vllm.entrypoints.cli.types import CLISubcommand
from vllm.entrypoints.utils import VLLM_SUBCMD_PARSER_EPILOG
from vllm.logger import init_logger
if typing.TYPE_CHECKING:
from vllm.utils.argparse_utils import FlexibleArgumentParser
else:
FlexibleArgumentParser = argparse.ArgumentParser
logger = init_logger(__name__)
class RunBatchSubcommand(CLISubcommand):
"""The `run-batch` subcommand for vLLM CLI."""
name = "run-batch"
@staticmethod
def cmd(args: argparse.Namespace) -> None:
from vllm.entrypoints.openai.run_batch import main as run_batch_main
logger.info(
"vLLM batch processing API version %s", importlib.metadata.version("vllm")
)
logger.info("args: %s", args)
# Start the Prometheus metrics server.
# LLMEngine uses the Prometheus client
# to publish metrics at the /metrics endpoint.
if args.enable_metrics:
from prometheus_client import start_http_server
logger.info("Prometheus metrics enabled")
start_http_server(port=args.port, addr=args.url)
else:
logger.info("Prometheus metrics disabled")
asyncio.run(run_batch_main(args))
def subparser_init(
self, subparsers: argparse._SubParsersAction
) -> FlexibleArgumentParser:
from vllm.entrypoints.openai.run_batch import make_arg_parser
run_batch_parser = subparsers.add_parser(
self.name,
help="Run batch prompts and write results to file.",
description=(
"Run batch prompts using vLLM's OpenAI-compatible API.\n"
"Supports local or HTTP input/output files."
),
usage="vllm run-batch -i INPUT.jsonl -o OUTPUT.jsonl --model <model>",
)
run_batch_parser = make_arg_parser(run_batch_parser)
run_batch_parser.epilog = VLLM_SUBCMD_PARSER_EPILOG.format(subcmd=self.name)
return run_batch_parser
def cmd_init() -> list[CLISubcommand]:
return [RunBatchSubcommand()]

View File

@@ -0,0 +1,249 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import signal
import uvloop
import vllm
import vllm.envs as envs
from vllm.entrypoints.cli.types import CLISubcommand
from vllm.entrypoints.openai.api_server import (
run_server,
run_server_worker,
setup_server,
)
from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args
from vllm.entrypoints.utils import VLLM_SUBCMD_PARSER_EPILOG
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.network_utils import get_tcp_uri
from vllm.utils.system_utils import decorate_logs, set_process_title
from vllm.v1.engine.core import EngineCoreProc
from vllm.v1.engine.utils import CoreEngineProcManager, launch_core_engines
from vllm.v1.executor import Executor
from vllm.v1.executor.multiproc_executor import MultiprocExecutor
from vllm.v1.metrics.prometheus import setup_multiprocess_prometheus
from vllm.v1.utils import APIServerProcessManager, wait_for_completion_or_failure
logger = init_logger(__name__)
DESCRIPTION = """Launch a local OpenAI-compatible API server to serve LLM
completions via HTTP. Defaults to Qwen/Qwen3-0.6B if no model is specified.
Search by using: `--help=<ConfigGroup>` to explore options by section (e.g.,
--help=ModelConfig, --help=Frontend)
Use `--help=all` to show all available flags at once.
"""
class ServeSubcommand(CLISubcommand):
"""The `serve` subcommand for the vLLM CLI."""
name = "serve"
@staticmethod
def cmd(args: argparse.Namespace) -> None:
# If model is specified in CLI (as positional arg), it takes precedence
if hasattr(args, "model_tag") and args.model_tag is not None:
args.model = args.model_tag
if args.headless or args.api_server_count < 1:
run_headless(args)
else:
if args.api_server_count > 1:
run_multi_api_server(args)
else:
# Single API server (this process).
uvloop.run(run_server(args))
def validate(self, args: argparse.Namespace) -> None:
validate_parsed_serve_args(args)
def subparser_init(
self, subparsers: argparse._SubParsersAction
) -> FlexibleArgumentParser:
serve_parser = subparsers.add_parser(
self.name, description=DESCRIPTION, usage="vllm serve [model_tag] [options]"
)
serve_parser = make_arg_parser(serve_parser)
serve_parser.epilog = VLLM_SUBCMD_PARSER_EPILOG.format(subcmd=self.name)
return serve_parser
def cmd_init() -> list[CLISubcommand]:
return [ServeSubcommand()]
def run_headless(args: argparse.Namespace):
if args.api_server_count > 1:
raise ValueError("api_server_count can't be set in headless mode")
# Create the EngineConfig.
engine_args = vllm.AsyncEngineArgs.from_cli_args(args)
usage_context = UsageContext.OPENAI_API_SERVER
vllm_config = engine_args.create_engine_config(
usage_context=usage_context, headless=True
)
if engine_args.data_parallel_hybrid_lb:
raise ValueError("data_parallel_hybrid_lb is not applicable in headless mode")
parallel_config = vllm_config.parallel_config
local_engine_count = parallel_config.data_parallel_size_local
if local_engine_count <= 0:
raise ValueError("data_parallel_size_local must be > 0 in headless mode")
shutdown_requested = False
# Catch SIGTERM and SIGINT to allow graceful shutdown.
def signal_handler(signum, frame):
nonlocal shutdown_requested
logger.debug("Received %d signal.", signum)
if not shutdown_requested:
shutdown_requested = True
raise SystemExit
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
if parallel_config.node_rank_within_dp > 0:
from vllm.version import __version__ as VLLM_VERSION
# Run headless workers (for multi-node PP/TP).
host = parallel_config.master_addr
head_node_address = f"{host}:{parallel_config.master_port}"
logger.info(
"Launching vLLM (v%s) headless multiproc executor, "
"with head node address %s for torch.distributed process group.",
VLLM_VERSION,
head_node_address,
)
executor = MultiprocExecutor(vllm_config, monitor_workers=False)
executor.start_worker_monitor(inline=True)
return
host = parallel_config.data_parallel_master_ip
port = parallel_config.data_parallel_rpc_port
handshake_address = get_tcp_uri(host, port)
logger.info(
"Launching %d data parallel engine(s) in headless mode, "
"with head node address %s.",
local_engine_count,
handshake_address,
)
# Create the engines.
engine_manager = CoreEngineProcManager(
target_fn=EngineCoreProc.run_engine_core,
local_engine_count=local_engine_count,
start_index=vllm_config.parallel_config.data_parallel_rank,
local_start_index=0,
vllm_config=vllm_config,
local_client=False,
handshake_address=handshake_address,
executor_class=Executor.get_class(vllm_config),
log_stats=not engine_args.disable_log_stats,
)
try:
engine_manager.join_first()
finally:
logger.info("Shutting down.")
engine_manager.close()
def run_multi_api_server(args: argparse.Namespace):
assert not args.headless
num_api_servers: int = args.api_server_count
assert num_api_servers > 0
if num_api_servers > 1:
setup_multiprocess_prometheus()
listen_address, sock = setup_server(args)
engine_args = vllm.AsyncEngineArgs.from_cli_args(args)
engine_args._api_process_count = num_api_servers
engine_args._api_process_rank = -1
usage_context = UsageContext.OPENAI_API_SERVER
vllm_config = engine_args.create_engine_config(usage_context=usage_context)
if num_api_servers > 1 and envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
raise ValueError(
"VLLM_ALLOW_RUNTIME_LORA_UPDATING cannot be used with api_server_count > 1"
)
executor_class = Executor.get_class(vllm_config)
log_stats = not engine_args.disable_log_stats
parallel_config = vllm_config.parallel_config
dp_rank = parallel_config.data_parallel_rank
external_dp_lb = parallel_config.data_parallel_external_lb
hybrid_dp_lb = parallel_config.data_parallel_hybrid_lb
assert external_dp_lb or hybrid_dp_lb or dp_rank == 0
api_server_manager: APIServerProcessManager | None = None
with launch_core_engines(
vllm_config, executor_class, log_stats, num_api_servers
) as (local_engine_manager, coordinator, addresses):
# Construct common args for the APIServerProcessManager up-front.
api_server_manager_kwargs = dict(
target_server_fn=run_api_server_worker_proc,
listen_address=listen_address,
sock=sock,
args=args,
num_servers=num_api_servers,
input_addresses=addresses.inputs,
output_addresses=addresses.outputs,
stats_update_address=coordinator.get_stats_publish_address()
if coordinator
else None,
)
# For dp ranks > 0 in external/hybrid DP LB modes, we must delay the
# start of the API servers until the local engine is started
# (after the launcher context manager exits),
# since we get the front-end stats update address from the coordinator
# via the handshake with the local engine.
if dp_rank == 0 or not (external_dp_lb or hybrid_dp_lb):
# Start API servers using the manager.
api_server_manager = APIServerProcessManager(**api_server_manager_kwargs)
# Start API servers now if they weren't already started.
if api_server_manager is None:
api_server_manager_kwargs["stats_update_address"] = (
addresses.frontend_stats_publish_address
)
api_server_manager = APIServerProcessManager(**api_server_manager_kwargs)
# Wait for API servers
wait_for_completion_or_failure(
api_server_manager=api_server_manager,
engine_manager=local_engine_manager,
coordinator=coordinator,
)
def run_api_server_worker_proc(
listen_address, sock, args, client_config=None, **uvicorn_kwargs
) -> None:
"""Entrypoint for individual API server worker processes."""
client_config = client_config or {}
server_index = client_config.get("client_index", 0)
# Set process title and add process-specific prefix to stdout and stderr.
set_process_title("APIServer", str(server_index))
decorate_logs()
uvloop.run(
run_server_worker(listen_address, sock, args, client_config, **uvicorn_kwargs)
)

View File

@@ -0,0 +1,29 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import typing
if typing.TYPE_CHECKING:
from vllm.utils.argparse_utils import FlexibleArgumentParser
else:
FlexibleArgumentParser = argparse.ArgumentParser
class CLISubcommand:
"""Base class for CLI argument handlers."""
name: str
@staticmethod
def cmd(args: argparse.Namespace) -> None:
raise NotImplementedError("Subclasses should implement this method")
def validate(self, args: argparse.Namespace) -> None:
# No validation by default
pass
def subparser_init(
self, subparsers: argparse._SubParsersAction
) -> FlexibleArgumentParser:
raise NotImplementedError("Subclasses should implement this method")

View File

@@ -0,0 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Shared constants for vLLM entrypoints.
"""
# HTTP header limits for h11 parser
# These constants help mitigate header abuse attacks
H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT = 4194304 # 4 MB
H11_MAX_HEADER_COUNT_DEFAULT = 256

View File

@@ -0,0 +1,572 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import contextlib
import json
import logging
from abc import ABC, abstractmethod
from contextlib import AsyncExitStack
from typing import TYPE_CHECKING, Union
from openai.types.responses.tool import Mcp
from openai_harmony import Author, Message, Role, StreamState, TextContent
from vllm import envs
from vllm.entrypoints.harmony_utils import (
get_encoding,
get_streamable_parser_for_assistant,
render_for_completion,
)
from vllm.entrypoints.tool import Tool
from vllm.entrypoints.tool_server import ToolServer
from vllm.outputs import RequestOutput
if TYPE_CHECKING:
from mcp.client import ClientSession
logger = logging.getLogger(__name__)
# This is currently needed as the tool type doesn't 1:1 match the
# tool namespace, which is what is used to look up the
# connection to the tool server
_TOOL_NAME_TO_TYPE_MAP = {
"browser": "web_search_preview",
"python": "code_interpreter",
"container": "container",
}
def _map_tool_name_to_tool_type(tool_name: str) -> str:
if tool_name not in _TOOL_NAME_TO_TYPE_MAP:
available_tools = ", ".join(_TOOL_NAME_TO_TYPE_MAP.keys())
raise ValueError(
f"Built-in tool name '{tool_name}' not defined in mapping. "
f"Available tools: {available_tools}"
)
return _TOOL_NAME_TO_TYPE_MAP[tool_name]
class TurnMetrics:
"""Tracks token and toolcall details for a single conversation turn."""
def __init__(
self,
input_tokens=0,
output_tokens=0,
cached_input_tokens=0,
tool_output_tokens=0,
):
self.input_tokens = input_tokens
self.output_tokens = output_tokens
self.cached_input_tokens = cached_input_tokens
self.tool_output_tokens = tool_output_tokens
def reset(self):
"""Reset counters for a new turn."""
self.input_tokens = 0
self.output_tokens = 0
self.cached_input_tokens = 0
self.tool_output_tokens = 0
def copy(self):
"""Create a copy of this turn's token counts."""
return TurnMetrics(
self.input_tokens,
self.output_tokens,
self.cached_input_tokens,
self.tool_output_tokens,
)
class ConversationContext(ABC):
@abstractmethod
def append_output(self, output: RequestOutput) -> None:
pass
@abstractmethod
def append_tool_output(self, output) -> None:
pass
@abstractmethod
async def call_tool(self) -> list[Message]:
pass
@abstractmethod
def need_builtin_tool_call(self) -> bool:
pass
@abstractmethod
def render_for_completion(self) -> list[int]:
pass
@abstractmethod
async def init_tool_sessions(
self,
tool_server: ToolServer | None,
exit_stack: AsyncExitStack,
request_id: str,
mcp_tools: dict[str, Mcp],
) -> None:
pass
@abstractmethod
async def cleanup_session(self) -> None:
raise NotImplementedError("Should not be called.")
def _create_json_parse_error_messages(
last_msg: Message, e: json.JSONDecodeError
) -> list[Message]:
"""
Creates an error message when json parse failed.
"""
error_msg = (
f"Error parsing tool arguments as JSON: {str(e)}. "
"Please ensure the tool call arguments are valid JSON and try again."
)
content = TextContent(text=error_msg)
author = Author(role=Role.TOOL, name=last_msg.recipient)
return [
Message(
author=author,
content=[content],
recipient=Role.ASSISTANT,
channel=last_msg.channel,
)
]
class SimpleContext(ConversationContext):
def __init__(self):
self.last_output = None
self.num_prompt_tokens = 0
self.num_output_tokens = 0
self.num_cached_tokens = 0
# todo num_reasoning_tokens is not implemented yet.
self.num_reasoning_tokens = 0
# not implemented yet for SimpleContext
self.all_turn_metrics = []
def append_output(self, output) -> None:
self.last_output = output
if not isinstance(output, RequestOutput):
raise ValueError("SimpleContext only supports RequestOutput.")
self.num_prompt_tokens = len(output.prompt_token_ids or [])
self.num_cached_tokens = output.num_cached_tokens or 0
self.num_output_tokens += len(output.outputs[0].token_ids or [])
def append_tool_output(self, output) -> None:
raise NotImplementedError("Should not be called.")
def need_builtin_tool_call(self) -> bool:
return False
async def call_tool(self) -> list[Message]:
raise NotImplementedError("Should not be called.")
def render_for_completion(self) -> list[int]:
raise NotImplementedError("Should not be called.")
async def init_tool_sessions(
self,
tool_server: ToolServer | None,
exit_stack: AsyncExitStack,
request_id: str,
mcp_tools: dict[str, Mcp],
) -> None:
pass
async def cleanup_session(self) -> None:
raise NotImplementedError("Should not be called.")
class HarmonyContext(ConversationContext):
def __init__(
self,
messages: list,
available_tools: list[str],
):
self._messages = messages
self.finish_reason: str | None = None
self.available_tools = available_tools
self._tool_sessions: dict[str, ClientSession | Tool] = {}
self.called_tools: set[str] = set()
self.parser = get_streamable_parser_for_assistant()
self.num_init_messages = len(messages)
self.num_prompt_tokens = 0
self.num_output_tokens = 0
self.num_cached_tokens = 0
self.num_reasoning_tokens = 0
self.num_tool_output_tokens = 0
# Turn tracking - replaces multiple individual tracking variables
self.current_turn_metrics = TurnMetrics()
# Track metrics for all turns
self.all_turn_metrics: list[TurnMetrics] = []
self.is_first_turn = True
self.first_tok_of_message = True # For streaming support
def _update_num_reasoning_tokens(self):
# Count all analysis and commentary channels as reasoning tokens
if self.parser.current_channel in {"analysis", "commentary"}:
self.num_reasoning_tokens += 1
def append_output(self, output: RequestOutput) -> None:
output_token_ids = output.outputs[0].token_ids
self.parser = get_streamable_parser_for_assistant()
for token_id in output_token_ids:
self.parser.process(token_id)
# Check if the current token is part of reasoning content
self._update_num_reasoning_tokens()
self._update_prefill_token_usage(output)
self._update_decode_token_usage(output)
# Append current turn to all turn list for next turn's calculations
self.all_turn_metrics.append(self.current_turn_metrics.copy())
self.current_turn_metrics.reset()
# append_output is called only once before tool calling
# in non-streaming case
# so we can append all the parser messages to _messages
output_msgs = self.parser.messages
# The responses finish reason is set in the last message
self.finish_reason = output.outputs[0].finish_reason
self._messages.extend(output_msgs)
def append_tool_output(self, output: list[Message]) -> None:
output_msgs = output
self._messages.extend(output_msgs)
def _update_prefill_token_usage(self, output: RequestOutput) -> None:
"""Update token usage statistics for the prefill phase of generation.
The prefill phase processes the input prompt tokens. This method:
1. Counts the prompt tokens for this turn
2. Calculates tool output tokens for multi-turn conversations
3. Updates cached token counts
4. Tracks state for next turn calculations
Tool output tokens are calculated as:
current_prompt_tokens - last_turn_prompt_tokens -
last_turn_output_tokens
This represents tokens added between turns (typically tool responses).
Args:
output: The RequestOutput containing prompt token information
"""
if output.prompt_token_ids is not None:
this_turn_input_tokens = len(output.prompt_token_ids)
else:
this_turn_input_tokens = 0
logger.error("RequestOutput appended contains no prompt_token_ids.")
# Update current turn input tokens
self.current_turn_metrics.input_tokens = this_turn_input_tokens
self.num_prompt_tokens += this_turn_input_tokens
# Calculate tool tokens (except on first turn)
if self.is_first_turn:
self.is_first_turn = False
else:
previous_turn = self.all_turn_metrics[-1]
# start counting tool after first turn
# tool tokens = this turn prefill - last turn prefill -
# last turn decode
this_turn_tool_tokens = (
self.current_turn_metrics.input_tokens
- previous_turn.input_tokens
- previous_turn.output_tokens
)
# Handle negative tool token counts (shouldn't happen in normal
# cases)
if this_turn_tool_tokens < 0:
logger.error(
"Negative tool output tokens calculated: %d "
"(current_input=%d, previous_input=%d, "
"previous_output=%d). Setting to 0.",
this_turn_tool_tokens,
self.current_turn_metrics.input_tokens,
previous_turn.input_tokens,
previous_turn.output_tokens,
)
this_turn_tool_tokens = 0
self.num_tool_output_tokens += this_turn_tool_tokens
self.current_turn_metrics.tool_output_tokens = this_turn_tool_tokens
# Update cached tokens
num_cached_token = output.num_cached_tokens
if num_cached_token is not None:
self.num_cached_tokens += num_cached_token
self.current_turn_metrics.cached_input_tokens = num_cached_token
def _update_decode_token_usage(self, output: RequestOutput) -> int:
"""Update token usage statistics for the decode phase of generation.
The decode phase processes the generated output tokens. This method:
1. Counts output tokens from all completion outputs
2. Updates the total output token count
3. Tracks tokens generated in the current turn
In streaming mode, this is called for each token generated.
In non-streaming mode, this is called once with all output tokens.
Args:
output: The RequestOutput containing generated token information
Returns:
int: Number of output tokens processed in this call
"""
updated_output_token_count = 0
if output.outputs:
for completion_output in output.outputs:
# only keep last round
updated_output_token_count += len(completion_output.token_ids)
self.num_output_tokens += updated_output_token_count
self.current_turn_metrics.output_tokens += updated_output_token_count
return updated_output_token_count
@property
def messages(self) -> list:
return self._messages
def need_builtin_tool_call(self) -> bool:
last_msg = self.messages[-1]
recipient = last_msg.recipient
return recipient is not None and (
recipient.startswith("browser.")
or recipient.startswith("python")
or recipient.startswith("container.")
)
async def call_tool(self) -> list[Message]:
if not self.messages:
return []
last_msg = self.messages[-1]
recipient = last_msg.recipient
if recipient is not None:
if recipient.startswith("browser."):
return await self.call_search_tool(
self._tool_sessions["browser"], last_msg
)
elif recipient.startswith("python"):
return await self.call_python_tool(
self._tool_sessions["python"], last_msg
)
elif recipient.startswith("container."):
return await self.call_container_tool(
self._tool_sessions["container"], last_msg
)
raise ValueError("No tool call found")
def render_for_completion(self) -> list[int]:
return render_for_completion(self.messages)
async def call_search_tool(
self, tool_session: Union["ClientSession", Tool], last_msg: Message
) -> list[Message]:
self.called_tools.add("browser")
if isinstance(tool_session, Tool):
return await tool_session.get_result(self)
tool_name = last_msg.recipient.split(".")[1]
if envs.VLLM_TOOL_JSON_ERROR_AUTOMATIC_RETRY:
try:
args = json.loads(last_msg.content[0].text)
except json.JSONDecodeError as e:
return _create_json_parse_error_messages(last_msg, e)
else:
args = json.loads(last_msg.content[0].text)
result = await tool_session.call_tool(tool_name, args)
result_str = result.content[0].text
content = TextContent(text=result_str)
author = Author(role=Role.TOOL, name=last_msg.recipient)
return [
Message(
author=author,
content=[content],
recipient=Role.ASSISTANT,
channel=last_msg.channel,
)
]
async def call_python_tool(
self, tool_session: Union["ClientSession", Tool], last_msg: Message
) -> list[Message]:
self.called_tools.add("python")
if isinstance(tool_session, Tool):
return await tool_session.get_result(self)
param = {
"code": last_msg.content[0].text,
}
result = await tool_session.call_tool("python", param)
result_str = result.content[0].text
content = TextContent(text=result_str)
author = Author(role=Role.TOOL, name="python")
return [
Message(
author=author,
content=[content],
channel=last_msg.channel,
recipient=Role.ASSISTANT,
)
]
async def init_tool_sessions(
self,
tool_server: ToolServer | None,
exit_stack: AsyncExitStack,
request_id: str,
mcp_tools: dict[str, Mcp],
):
if tool_server:
for tool_name in self.available_tools:
if tool_name not in self._tool_sessions:
tool_type = _map_tool_name_to_tool_type(tool_name)
headers = (
mcp_tools[tool_type].headers if tool_type in mcp_tools else None
)
tool_session = await exit_stack.enter_async_context(
tool_server.new_session(tool_name, request_id, headers)
)
self._tool_sessions[tool_name] = tool_session
exit_stack.push_async_exit(self.cleanup_session)
async def call_container_tool(
self, tool_session: Union["ClientSession", Tool], last_msg: Message
) -> list[Message]:
"""
Call container tool. Expect this to be run in a stateful docker
with command line terminal.
The official container tool would at least
expect the following format:
- for tool name: exec
- args:
{
"cmd":List[str] "command to execute",
"workdir":optional[str] "current working directory",
"env":optional[object/dict] "environment variables",
"session_name":optional[str] "session name",
"timeout":optional[int] "timeout in seconds",
"user":optional[str] "user name",
}
"""
self.called_tools.add("container")
if isinstance(tool_session, Tool):
return await tool_session.get_result(self)
tool_name = last_msg.recipient.split(".")[1].split(" ")[0]
if envs.VLLM_TOOL_JSON_ERROR_AUTOMATIC_RETRY:
try:
args = json.loads(last_msg.content[0].text)
except json.JSONDecodeError as e:
return _create_json_parse_error_messages(last_msg, e)
else:
args = json.loads(last_msg.content[0].text)
result = await tool_session.call_tool(tool_name, args)
result_str = result.content[0].text
content = TextContent(text=result_str)
author = Author(role=Role.TOOL, name=last_msg.recipient)
return [
Message(
author=author,
content=[content],
recipient=Role.ASSISTANT,
channel=last_msg.channel,
)
]
async def cleanup_session(self, *args, **kwargs) -> None:
"""Can be used as coro to used in __aexit__"""
async def cleanup_tool_session(tool_session):
if not isinstance(tool_session, Tool):
logger.info(
"Cleaning up tool session for %s", tool_session._client_info
)
with contextlib.suppress(Exception):
await tool_session.call_tool("cleanup_session", {})
await asyncio.gather(
*(
cleanup_tool_session(self._tool_sessions[tool])
for tool in self.called_tools
)
)
class StreamingHarmonyContext(HarmonyContext):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.last_output = None
self.parser = get_streamable_parser_for_assistant()
self.encoding = get_encoding()
self.last_tok = None
self.first_tok_of_message = True
@property
def messages(self) -> list:
return self._messages
def append_output(self, output: RequestOutput) -> None:
# append_output is called for each output token in streaming case,
# so we only want to add the prompt tokens once for each message.
if self.first_tok_of_message:
self._update_prefill_token_usage(output)
# Reset self.first_tok_of_message if needed:
# if the current token is the last one of the current message
# (finished=True), then the next token processed will mark the
# beginning of a new message
self.first_tok_of_message = output.finished
for tok in output.outputs[0].token_ids:
self.parser.process(tok)
self._update_decode_token_usage(output)
# For streaming, update previous turn when message is complete
if output.finished:
self.all_turn_metrics.append(self.current_turn_metrics.copy())
self.current_turn_metrics.reset()
# Check if the current token is part of reasoning content
self._update_num_reasoning_tokens()
self.last_tok = tok
if len(self._messages) - self.num_init_messages < len(self.parser.messages):
self._messages.extend(
self.parser.messages[len(self._messages) - self.num_init_messages :]
)
def append_tool_output(self, output: list[Message]) -> None:
# Handle the case of tool output in direct message format
assert len(output) == 1, "Tool output should be a single message"
msg = output[0]
# Sometimes the recipient is not set for tool messages,
# so we set it to "assistant"
if msg.author.role == Role.TOOL and msg.recipient is None:
msg.recipient = "assistant"
toks = self.encoding.render(msg)
for tok in toks:
self.parser.process(tok)
self.last_tok = toks[-1]
# TODO: add tool_output messages to self._messages
def is_expecting_start(self) -> bool:
return self.parser.state == StreamState.EXPECT_START
def is_assistant_action_turn(self) -> bool:
return self.last_tok in self.encoding.stop_tokens_for_assistant_actions()
def render_for_completion(self) -> list[int]:
# now this list of tokens as next turn's starting tokens
# `<|start|>assistant`,
# we need to process them in parser.
rendered_tokens = super().render_for_completion()
last_n = -1
to_process = []
while rendered_tokens[last_n] != self.last_tok:
to_process.append(rendered_tokens[last_n])
last_n -= 1
for tok in reversed(to_process):
self.parser.process(tok)
return rendered_tokens

View File

@@ -0,0 +1,57 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import model_hosting_container_standards.sagemaker as sagemaker_standards
from fastapi import APIRouter, Depends, Request
from fastapi.responses import JSONResponse, Response
from vllm.entrypoints.openai.api_server import models, validate_json_request
from vllm.entrypoints.openai.protocol import (
ErrorResponse,
LoadLoRAAdapterRequest,
UnloadLoRAAdapterRequest,
)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.logger import init_logger
logger = init_logger(__name__)
def register_dynamic_lora_routes(router: APIRouter):
@sagemaker_standards.register_load_adapter_handler(
request_shape={
"lora_name": "body.name",
"lora_path": "body.src",
},
)
@router.post("/v1/load_lora_adapter", dependencies=[Depends(validate_json_request)])
async def load_lora_adapter(request: LoadLoRAAdapterRequest, raw_request: Request):
handler: OpenAIServingModels = models(raw_request)
response = await handler.load_lora_adapter(request)
if isinstance(response, ErrorResponse):
return JSONResponse(
content=response.model_dump(), status_code=response.error.code
)
return Response(status_code=200, content=response)
@sagemaker_standards.register_unload_adapter_handler(
request_shape={
"lora_name": "path_params.adapter_name",
}
)
@router.post(
"/v1/unload_lora_adapter", dependencies=[Depends(validate_json_request)]
)
async def unload_lora_adapter(
request: UnloadLoRAAdapterRequest, raw_request: Request
):
handler: OpenAIServingModels = models(raw_request)
response = await handler.unload_lora_adapter(request)
if isinstance(response, ErrorResponse):
return JSONResponse(
content=response.model_dump(), status_code=response.error.code
)
return Response(status_code=200, content=response)
return router

View File

@@ -0,0 +1,535 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import datetime
import json
from collections.abc import Iterable, Sequence
from typing import Literal
from openai.types.responses import (
ResponseFunctionToolCall,
ResponseOutputItem,
ResponseOutputMessage,
ResponseOutputText,
ResponseReasoningItem,
)
from openai.types.responses.response_function_web_search import (
ActionFind,
ActionOpenPage,
ActionSearch,
ResponseFunctionWebSearch,
)
from openai.types.responses.response_reasoning_item import (
Content as ResponseReasoningTextContent,
)
from openai.types.responses.tool import Tool
from openai_harmony import (
Author,
ChannelConfig,
Conversation,
DeveloperContent,
HarmonyEncodingName,
Message,
ReasoningEffort,
Role,
StreamableParser,
SystemContent,
TextContent,
ToolDescription,
load_harmony_encoding,
)
from openai_harmony import Message as OpenAIHarmonyMessage
from openai_harmony import Role as OpenAIHarmonyRole
from vllm import envs
from vllm.entrypoints.openai.protocol import (
ChatCompletionToolsParam,
ResponseInputOutputItem,
ResponsesRequest,
)
from vllm.utils import random_uuid
REASONING_EFFORT = {
"high": ReasoningEffort.HIGH,
"medium": ReasoningEffort.MEDIUM,
"low": ReasoningEffort.LOW,
}
_harmony_encoding = None
# Builtin tools that should be included in the system message when
# they are available and requested by the user.
# Tool args are provided by MCP tool descriptions. Output
# of the tools are stringified.
MCP_BUILTIN_TOOLS: set[str] = {
"web_search_preview",
"code_interpreter",
"container",
}
def has_custom_tools(tool_types: set[str]) -> bool:
"""
Checks if the given tool types are custom tools
(i.e. any tool other than MCP buildin tools)
"""
return not tool_types.issubset(MCP_BUILTIN_TOOLS)
def get_encoding():
global _harmony_encoding
if _harmony_encoding is None:
_harmony_encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
return _harmony_encoding
def get_system_message(
model_identity: str | None = None,
reasoning_effort: Literal["high", "medium", "low"] | None = None,
start_date: str | None = None,
browser_description: str | None = None,
python_description: str | None = None,
container_description: str | None = None,
instructions: str | None = None,
with_custom_tools: bool = False,
) -> Message:
sys_msg_content = SystemContent.new()
if model_identity is not None:
sys_msg_content = sys_msg_content.with_model_identity(model_identity)
if instructions is not None and envs.VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS:
current_identity = sys_msg_content.model_identity
new_identity = (
f"{current_identity}\n{instructions}" if current_identity else instructions
)
sys_msg_content = sys_msg_content.with_model_identity(new_identity)
if reasoning_effort is not None:
sys_msg_content = sys_msg_content.with_reasoning_effort(
REASONING_EFFORT[reasoning_effort]
)
if start_date is None:
# NOTE(woosuk): This brings non-determinism in vLLM. Be careful.
start_date = datetime.datetime.now().strftime("%Y-%m-%d")
sys_msg_content = sys_msg_content.with_conversation_start_date(start_date)
if browser_description is not None:
sys_msg_content = sys_msg_content.with_tools(browser_description)
if python_description is not None:
sys_msg_content = sys_msg_content.with_tools(python_description)
if container_description is not None:
sys_msg_content = sys_msg_content.with_tools(container_description)
if not with_custom_tools:
channel_config = sys_msg_content.channel_config
invalid_channel = "commentary"
new_config = ChannelConfig.require_channels(
[c for c in channel_config.valid_channels if c != invalid_channel]
)
sys_msg_content = sys_msg_content.with_channel_config(new_config)
sys_msg = Message.from_role_and_content(Role.SYSTEM, sys_msg_content)
return sys_msg
def create_tool_definition(tool: ChatCompletionToolsParam | Tool):
if isinstance(tool, ChatCompletionToolsParam):
return ToolDescription.new(
name=tool.function.name,
description=tool.function.description,
parameters=tool.function.parameters,
)
return ToolDescription.new(
name=tool.name,
description=tool.description,
parameters=tool.parameters,
)
def get_developer_message(
instructions: str | None = None,
tools: list[Tool | ChatCompletionToolsParam] | None = None,
) -> Message:
dev_msg_content = DeveloperContent.new()
if instructions is not None and not envs.VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS:
dev_msg_content = dev_msg_content.with_instructions(instructions)
if tools is not None:
function_tools: list[Tool | ChatCompletionToolsParam] = []
for tool in tools:
if tool.type in (
"web_search_preview",
"code_interpreter",
"container",
"mcp",
):
# These are built-in tools that are added to the system message.
# Adding in MCP for now until we support MCP tools executed
# server side
pass
elif tool.type == "function":
function_tools.append(tool)
else:
raise ValueError(f"tool type {tool.type} not supported")
if function_tools:
function_tool_descriptions = [
create_tool_definition(tool) for tool in function_tools
]
dev_msg_content = dev_msg_content.with_function_tools(
function_tool_descriptions
)
dev_msg = Message.from_role_and_content(Role.DEVELOPER, dev_msg_content)
return dev_msg
def get_user_message(content: str) -> Message:
return Message.from_role_and_content(Role.USER, content)
def parse_response_input(
response_msg: ResponseInputOutputItem,
prev_responses: list[ResponseOutputItem | ResponseReasoningItem],
) -> Message:
if not isinstance(response_msg, dict):
response_msg = response_msg.model_dump()
if "type" not in response_msg or response_msg["type"] == "message":
role = response_msg["role"]
content = response_msg["content"]
if role == "system":
# User is trying to set a system message. Change it to:
# <|start|>developer<|message|># Instructions
# {instructions}<|end|>
role = "developer"
text_prefix = "Instructions:\n"
else:
text_prefix = ""
if isinstance(content, str):
msg = Message.from_role_and_content(role, text_prefix + content)
else:
contents = [TextContent(text=text_prefix + c["text"]) for c in content]
msg = Message.from_role_and_contents(role, contents)
if role == "assistant":
msg = msg.with_channel("final")
elif response_msg["type"] == "function_call_output":
call_id = response_msg["call_id"]
call_response: ResponseFunctionToolCall | None = None
for prev_response in reversed(prev_responses):
if (
isinstance(prev_response, ResponseFunctionToolCall)
and prev_response.call_id == call_id
):
call_response = prev_response
break
if call_response is None:
raise ValueError(f"No call message found for {call_id}")
msg = Message.from_author_and_content(
Author.new(Role.TOOL, f"functions.{call_response.name}"),
response_msg["output"],
)
elif response_msg["type"] == "reasoning":
content = response_msg["content"]
assert len(content) == 1
msg = Message.from_role_and_content(Role.ASSISTANT, content[0]["text"])
elif response_msg["type"] == "function_call":
msg = Message.from_role_and_content(Role.ASSISTANT, response_msg["arguments"])
msg = msg.with_channel("commentary")
msg = msg.with_recipient(f"functions.{response_msg['name']}")
msg = msg.with_content_type("json")
else:
raise ValueError(f"Unknown input type: {response_msg['type']}")
return msg
def parse_input_to_harmony_message(chat_msg) -> list[Message]:
if not isinstance(chat_msg, dict):
# Handle Pydantic models
chat_msg = chat_msg.model_dump(exclude_none=True)
role = chat_msg.get("role")
# Assistant message with tool calls
tool_calls = chat_msg.get("tool_calls")
if role == "assistant" and tool_calls:
msgs: list[Message] = []
for call in tool_calls:
func = call.get("function", {})
name = func.get("name", "")
arguments = func.get("arguments", "") or ""
msg = Message.from_role_and_content(Role.ASSISTANT, arguments)
msg = msg.with_channel("commentary")
msg = msg.with_recipient(f"functions.{name}")
msg = msg.with_content_type("json")
msgs.append(msg)
return msgs
# Tool role message (tool output)
if role == "tool":
name = chat_msg.get("name", "")
content = chat_msg.get("content", "") or ""
if isinstance(content, list):
# Handle array format for tool message content
# by concatenating all text parts.
content = "".join(
item.get("text", "")
for item in content
if isinstance(item, dict) and item.get("type") == "text"
)
msg = Message.from_author_and_content(
Author.new(Role.TOOL, f"functions.{name}"), content
).with_channel("commentary")
return [msg]
# Default: user/assistant/system messages with content
content = chat_msg.get("content", "")
if isinstance(content, str):
contents = [TextContent(text=content)]
else:
# TODO: Support refusal.
contents = [TextContent(text=c.get("text", "")) for c in content]
msg = Message.from_role_and_contents(role, contents)
return [msg]
def construct_harmony_previous_input_messages(
request: ResponsesRequest,
) -> list[OpenAIHarmonyMessage]:
messages: list[OpenAIHarmonyMessage] = []
if request.previous_input_messages:
for message in request.previous_input_messages:
# Handle both OpenAIHarmonyMessage objects and dictionary inputs
if isinstance(message, OpenAIHarmonyMessage):
message_role = message.author.role
# To match OpenAI, instructions, reasoning and tools are
# always taken from the most recent Responses API request
# not carried over from previous requests
if (
message_role == OpenAIHarmonyRole.SYSTEM
or message_role == OpenAIHarmonyRole.DEVELOPER
):
continue
messages.append(message)
else:
harmony_messages = parse_input_to_harmony_message(message)
for harmony_msg in harmony_messages:
message_role = harmony_msg.author.role
# To match OpenAI, instructions, reasoning and tools are
# always taken from the most recent Responses API request
# not carried over from previous requests
if (
message_role == OpenAIHarmonyRole.SYSTEM
or message_role == OpenAIHarmonyRole.DEVELOPER
):
continue
messages.append(harmony_msg)
return messages
def render_for_completion(messages: list[Message]) -> list[int]:
conversation = Conversation.from_messages(messages)
token_ids = get_encoding().render_conversation_for_completion(
conversation, Role.ASSISTANT
)
return token_ids
def parse_output_message(message: Message) -> list[ResponseOutputItem]:
"""
Parse a Harmony message into a list of output response items.
"""
if message.author.role != "assistant":
# This is a message from a tool to the assistant (e.g., search result).
# Don't include it in the final output for now. This aligns with
# OpenAI's behavior on models like o4-mini.
return []
output_items: list[ResponseOutputItem] = []
recipient = message.recipient
if recipient is not None and recipient.startswith("browser."):
if len(message.content) != 1:
raise ValueError("Invalid number of contents in browser message")
content = message.content[0]
# We do not need to check the VLLM_TOOL_JSON_ERROR_AUTOMATIC_RETRY
# env variable since if it is not set, we are certain the json is valid
# The use of Actions for web search will be removed entirely in
# the future, so this is only necessary temporarily
try:
browser_call = json.loads(content.text)
except json.JSONDecodeError:
# If the content is not valid JSON, then it was
# caught and retried by vLLM, which means we
# need to make note of that so the user is aware
json_retry_output_message = (
f"Invalid JSON args, caught and retried: {content.text}"
)
browser_call = {
"query": json_retry_output_message,
"url": json_retry_output_message,
"pattern": json_retry_output_message,
}
# TODO: translate to url properly!
if recipient == "browser.search":
action = ActionSearch(
query=f"cursor:{browser_call.get('query', '')}", type="search"
)
elif recipient == "browser.open":
action = ActionOpenPage(
url=f"cursor:{browser_call.get('url', '')}", type="open_page"
)
elif recipient == "browser.find":
action = ActionFind(
pattern=browser_call["pattern"],
url=f"cursor:{browser_call.get('url', '')}",
type="find",
)
else:
raise ValueError(f"Unknown browser action: {recipient}")
web_search_item = ResponseFunctionWebSearch(
id=f"ws_{random_uuid()}",
action=action,
status="completed",
type="web_search_call",
)
output_items.append(web_search_item)
elif message.channel == "analysis":
for content in message.content:
reasoning_item = ResponseReasoningItem(
id=f"rs_{random_uuid()}",
summary=[],
type="reasoning",
content=[
ResponseReasoningTextContent(
text=content.text, type="reasoning_text"
)
],
status=None,
)
output_items.append(reasoning_item)
elif message.channel == "commentary":
if recipient is not None and recipient.startswith("functions."):
function_name = recipient.split(".")[-1]
for content in message.content:
random_id = random_uuid()
response_item = ResponseFunctionToolCall(
arguments=content.text,
call_id=f"call_{random_id}",
type="function_call",
name=function_name,
id=f"fc_{random_id}",
)
output_items.append(response_item)
elif recipient is not None and (
recipient.startswith("python")
or recipient.startswith("browser")
or recipient.startswith("container")
):
for content in message.content:
reasoning_item = ResponseReasoningItem(
id=f"rs_{random_uuid()}",
summary=[],
type="reasoning",
content=[
ResponseReasoningTextContent(
text=content.text, type="reasoning_text"
)
],
status=None,
)
output_items.append(reasoning_item)
else:
raise ValueError(f"Unknown recipient: {recipient}")
elif message.channel == "final":
contents = []
for content in message.content:
output_text = ResponseOutputText(
text=content.text,
annotations=[], # TODO
type="output_text",
logprobs=None, # TODO
)
contents.append(output_text)
text_item = ResponseOutputMessage(
id=f"msg_{random_uuid()}",
content=contents,
role=message.author.role,
status="completed",
type="message",
)
output_items.append(text_item)
else:
raise ValueError(f"Unknown channel: {message.channel}")
return output_items
def parse_remaining_state(parser: StreamableParser) -> list[ResponseOutputItem]:
if not parser.current_content:
return []
if parser.current_role != Role.ASSISTANT:
return []
current_recipient = parser.current_recipient
if current_recipient is not None and current_recipient.startswith("browser."):
return []
if parser.current_channel == "analysis":
reasoning_item = ResponseReasoningItem(
id=f"rs_{random_uuid()}",
summary=[],
type="reasoning",
content=[
ResponseReasoningTextContent(
text=parser.current_content, type="reasoning_text"
)
],
status=None,
)
return [reasoning_item]
elif parser.current_channel == "final":
output_text = ResponseOutputText(
text=parser.current_content,
annotations=[], # TODO
type="output_text",
logprobs=None, # TODO
)
text_item = ResponseOutputMessage(
id=f"msg_{random_uuid()}",
content=[output_text],
role="assistant",
# if the parser still has messages (ie if the generator got cut
# abruptly), this should be incomplete
status="incomplete",
type="message",
)
return [text_item]
return []
def get_stop_tokens_for_assistant_actions() -> list[int]:
return get_encoding().stop_tokens_for_assistant_actions()
def get_streamable_parser_for_assistant() -> StreamableParser:
return StreamableParser(get_encoding(), role=Role.ASSISTANT)
def parse_output_into_messages(token_ids: Iterable[int]) -> StreamableParser:
parser = get_streamable_parser_for_assistant()
for token_id in token_ids:
parser.process(token_id)
return parser
def parse_chat_output(
token_ids: Sequence[int],
) -> tuple[str | None, str | None, bool]:
parser = parse_output_into_messages(token_ids)
output_msgs = parser.messages
is_tool_call = False # TODO: update this when tool call is supported
if len(output_msgs) == 0:
# The generation has stopped during reasoning.
reasoning = parser.current_content
final_content = None
elif len(output_msgs) == 1:
# The generation has stopped during final message.
reasoning = output_msgs[0].content[0].text
final_content = parser.current_content
else:
reasoning_msg = output_msgs[:-1]
final_msg = output_msgs[-1]
reasoning = "\n".join([msg.content[0].text for msg in reasoning_msg])
final_content = final_msg.content[0].text
return reasoning, final_content, is_tool_call

View File

@@ -0,0 +1,175 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import signal
import socket
from http import HTTPStatus
from typing import Any
import uvicorn
from fastapi import FastAPI, Request, Response
from vllm import envs
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.constants import (
H11_MAX_HEADER_COUNT_DEFAULT,
H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT,
)
from vllm.entrypoints.ssl import SSLCertRefresher
from vllm.logger import init_logger
from vllm.utils.network_utils import find_process_using_port
from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
logger = init_logger(__name__)
async def serve_http(
app: FastAPI,
sock: socket.socket | None,
enable_ssl_refresh: bool = False,
**uvicorn_kwargs: Any,
):
"""
Start a FastAPI app using Uvicorn, with support for custom Uvicorn config
options. Supports http header limits via h11_max_incomplete_event_size and
h11_max_header_count.
"""
logger.info("Available routes are:")
for route in app.routes:
methods = getattr(route, "methods", None)
path = getattr(route, "path", None)
if methods is None or path is None:
continue
logger.info("Route: %s, Methods: %s", path, ", ".join(methods))
# Extract header limit options if present
h11_max_incomplete_event_size = uvicorn_kwargs.pop(
"h11_max_incomplete_event_size", None
)
h11_max_header_count = uvicorn_kwargs.pop("h11_max_header_count", None)
# Set safe defaults if not provided
if h11_max_incomplete_event_size is None:
h11_max_incomplete_event_size = H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT
if h11_max_header_count is None:
h11_max_header_count = H11_MAX_HEADER_COUNT_DEFAULT
config = uvicorn.Config(app, **uvicorn_kwargs)
# Set header limits
config.h11_max_incomplete_event_size = h11_max_incomplete_event_size
config.h11_max_header_count = h11_max_header_count
config.load()
server = uvicorn.Server(config)
_add_shutdown_handlers(app, server)
loop = asyncio.get_running_loop()
watchdog_task = loop.create_task(watchdog_loop(server, app.state.engine_client))
server_task = loop.create_task(server.serve(sockets=[sock] if sock else None))
ssl_cert_refresher = (
None
if not enable_ssl_refresh
else SSLCertRefresher(
ssl_context=config.ssl,
key_path=config.ssl_keyfile,
cert_path=config.ssl_certfile,
ca_path=config.ssl_ca_certs,
)
)
def signal_handler() -> None:
# prevents the uvicorn signal handler to exit early
server_task.cancel()
watchdog_task.cancel()
if ssl_cert_refresher:
ssl_cert_refresher.stop()
async def dummy_shutdown() -> None:
pass
loop.add_signal_handler(signal.SIGINT, signal_handler)
loop.add_signal_handler(signal.SIGTERM, signal_handler)
try:
await server_task
return dummy_shutdown()
except asyncio.CancelledError:
port = uvicorn_kwargs["port"]
process = find_process_using_port(port)
if process is not None:
logger.warning(
"port %s is used by process %s launched with command:\n%s",
port,
process,
" ".join(process.cmdline()),
)
logger.info("Shutting down FastAPI HTTP server.")
return server.shutdown()
finally:
watchdog_task.cancel()
async def watchdog_loop(server: uvicorn.Server, engine: EngineClient):
"""
# Watchdog task that runs in the background, checking
# for error state in the engine. Needed to trigger shutdown
# if an exception arises is StreamingResponse() generator.
"""
VLLM_WATCHDOG_TIME_S = 5.0
while True:
await asyncio.sleep(VLLM_WATCHDOG_TIME_S)
terminate_if_errored(server, engine)
def terminate_if_errored(server: uvicorn.Server, engine: EngineClient):
"""
See discussions here on shutting down a uvicorn server
https://github.com/encode/uvicorn/discussions/1103
In this case we cannot await the server shutdown here
because handler must first return to close the connection
for this request.
"""
engine_errored = engine.errored and not engine.is_running
if not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH and engine_errored:
server.should_exit = True
def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server) -> None:
"""
VLLM V1 AsyncLLM catches exceptions and returns
only two types: EngineGenerateError and EngineDeadError.
EngineGenerateError is raised by the per request generate()
method. This error could be request specific (and therefore
recoverable - e.g. if there is an error in input processing).
EngineDeadError is raised by the background output_handler
method. This error is global and therefore not recoverable.
We register these @app.exception_handlers to return nice
responses to the end user if they occur and shut down if needed.
See https://fastapi.tiangolo.com/tutorial/handling-errors/
for more details on how exception handlers work.
If an exception is encountered in a StreamingResponse
generator, the exception is not raised, since we already sent
a 200 status. Rather, we send an error message as the next chunk.
Since the exception is not raised, this means that the server
will not automatically shut down. Instead, we use the watchdog
background task for check for errored state.
"""
@app.exception_handler(RuntimeError)
@app.exception_handler(EngineDeadError)
@app.exception_handler(EngineGenerateError)
async def runtime_exception_handler(request: Request, __):
terminate_if_errored(
server=server,
engine=request.app.state.engine_client,
)
return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)

1768
vllm_old/entrypoints/llm.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,84 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
import torch
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import BeamSearchParams, SamplingParams
logger = init_logger(__name__)
class RequestLogger:
def __init__(self, *, max_log_len: int | None) -> None:
self.max_log_len = max_log_len
def log_inputs(
self,
request_id: str,
prompt: str | None,
prompt_token_ids: list[int] | None,
prompt_embeds: torch.Tensor | None,
params: SamplingParams | PoolingParams | BeamSearchParams | None,
lora_request: LoRARequest | None,
) -> None:
max_log_len = self.max_log_len
if max_log_len is not None:
if prompt is not None:
prompt = prompt[:max_log_len]
if prompt_token_ids is not None:
prompt_token_ids = prompt_token_ids[:max_log_len]
logger.debug(
"Request %s details: prompt: %r, "
"prompt_token_ids: %s, "
"prompt_embeds shape: %s.",
request_id,
prompt,
prompt_token_ids,
prompt_embeds.shape if prompt_embeds is not None else None,
)
logger.info(
"Received request %s: params: %s, lora_request: %s.",
request_id,
params,
lora_request,
)
def log_outputs(
self,
request_id: str,
outputs: str,
output_token_ids: Sequence[int] | None,
finish_reason: str | None = None,
is_streaming: bool = False,
delta: bool = False,
) -> None:
max_log_len = self.max_log_len
if max_log_len is not None:
if outputs is not None:
outputs = outputs[:max_log_len]
if output_token_ids is not None:
# Convert to list and apply truncation
output_token_ids = list(output_token_ids)[:max_log_len]
stream_info = ""
if is_streaming:
stream_info = " (streaming delta)" if delta else " (streaming complete)"
logger.info(
"Generated response %s%s: output: %r, "
"output_token_ids: %s, finish_reason: %s",
request_id,
stream_info,
outputs,
output_token_ids,
finish_reason,
)

View File

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,302 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This file contains the command line arguments for the vLLM's
OpenAI-compatible server. It is kept in a separate file for documentation
purposes.
"""
import argparse
import json
import ssl
from collections.abc import Sequence
from dataclasses import field
from typing import Literal
from pydantic.dataclasses import dataclass
import vllm.envs as envs
from vllm.config import config
from vllm.engine.arg_utils import AsyncEngineArgs, optional_type
from vllm.entrypoints.chat_utils import (
ChatTemplateContentFormatOption,
validate_chat_template,
)
from vllm.entrypoints.constants import (
H11_MAX_HEADER_COUNT_DEFAULT,
H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT,
)
from vllm.entrypoints.openai.serving_models import LoRAModulePath
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
from vllm.logger import init_logger
from vllm.utils.argparse_utils import FlexibleArgumentParser
logger = init_logger(__name__)
class LoRAParserAction(argparse.Action):
def __call__(
self,
parser: argparse.ArgumentParser,
namespace: argparse.Namespace,
values: str | Sequence[str] | None,
option_string: str | None = None,
):
if values is None:
values = []
if isinstance(values, str):
raise TypeError("Expected values to be a list")
lora_list: list[LoRAModulePath] = []
for item in values:
if item in [None, ""]: # Skip if item is None or empty string
continue
if "=" in item and "," not in item: # Old format: name=path
name, path = item.split("=")
lora_list.append(LoRAModulePath(name, path))
else: # Assume JSON format
try:
lora_dict = json.loads(item)
lora = LoRAModulePath(**lora_dict)
lora_list.append(lora)
except json.JSONDecodeError:
parser.error(f"Invalid JSON format for --lora-modules: {item}")
except TypeError as e:
parser.error(
f"Invalid fields for --lora-modules: {item} - {str(e)}"
)
setattr(namespace, self.dest, lora_list)
@config
@dataclass
class FrontendArgs:
"""Arguments for the OpenAI-compatible frontend server."""
host: str | None = None
"""Host name."""
port: int = 8000
"""Port number."""
uds: str | None = None
"""Unix domain socket path. If set, host and port arguments are ignored."""
uvicorn_log_level: Literal[
"debug", "info", "warning", "error", "critical", "trace"
] = "info"
"""Log level for uvicorn."""
disable_uvicorn_access_log: bool = False
"""Disable uvicorn access log."""
allow_credentials: bool = False
"""Allow credentials."""
allowed_origins: list[str] = field(default_factory=lambda: ["*"])
"""Allowed origins."""
allowed_methods: list[str] = field(default_factory=lambda: ["*"])
"""Allowed methods."""
allowed_headers: list[str] = field(default_factory=lambda: ["*"])
"""Allowed headers."""
api_key: list[str] | None = None
"""If provided, the server will require one of these keys to be presented in
the header."""
lora_modules: list[LoRAModulePath] | None = None
"""LoRA modules configurations in either 'name=path' format or JSON format
or JSON list format. Example (old format): `'name=path'` Example (new
format): `{\"name\": \"name\", \"path\": \"lora_path\",
\"base_model_name\": \"id\"}`"""
chat_template: str | None = None
"""The file path to the chat template, or the template in single-line form
for the specified model."""
chat_template_content_format: ChatTemplateContentFormatOption = "auto"
"""The format to render message content within a chat template.
* "string" will render the content as a string. Example: `"Hello World"`
* "openai" will render the content as a list of dictionaries, similar to
OpenAI schema. Example: `[{"type": "text", "text": "Hello world!"}]`"""
trust_request_chat_template: bool = False
"""Whether to trust the chat template provided in the request. If False,
the server will always use the chat template specified by `--chat-template`
or the ones from tokenizer."""
response_role: str = "assistant"
"""The role name to return if `request.add_generation_prompt=true`."""
ssl_keyfile: str | None = None
"""The file path to the SSL key file."""
ssl_certfile: str | None = None
"""The file path to the SSL cert file."""
ssl_ca_certs: str | None = None
"""The CA certificates file."""
enable_ssl_refresh: bool = False
"""Refresh SSL Context when SSL certificate files change"""
ssl_cert_reqs: int = int(ssl.CERT_NONE)
"""Whether client certificate is required (see stdlib ssl module's)."""
root_path: str | None = None
"""FastAPI root_path when app is behind a path based routing proxy."""
middleware: list[str] = field(default_factory=lambda: [])
"""Additional ASGI middleware to apply to the app. We accept multiple
--middleware arguments. The value should be an import path. If a function
is provided, vLLM will add it to the server using
`@app.middleware('http')`. If a class is provided, vLLM will
add it to the server using `app.add_middleware()`."""
return_tokens_as_token_ids: bool = False
"""When `--max-logprobs` is specified, represents single tokens as
strings of the form 'token_id:{token_id}' so that tokens that are not
JSON-encodable can be identified."""
disable_frontend_multiprocessing: bool = False
"""If specified, will run the OpenAI frontend server in the same process as
the model serving engine."""
enable_request_id_headers: bool = False
"""If specified, API server will add X-Request-Id header to responses."""
enable_auto_tool_choice: bool = False
"""Enable auto tool choice for supported models. Use `--tool-call-parser`
to specify which parser to use."""
exclude_tools_when_tool_choice_none: bool = False
"""If specified, exclude tool definitions in prompts when
tool_choice='none'."""
tool_call_parser: str | None = None
"""Select the tool call parser depending on the model that you're using.
This is used to parse the model-generated tool call into OpenAI API format.
Required for `--enable-auto-tool-choice`. You can choose any option from
the built-in parsers or register a plugin via `--tool-parser-plugin`."""
tool_parser_plugin: str = ""
"""Special the tool parser plugin write to parse the model-generated tool
into OpenAI API format, the name register in this plugin can be used in
`--tool-call-parser`."""
tool_server: str | None = None
"""Comma-separated list of host:port pairs (IPv4, IPv6, or hostname).
Examples: 127.0.0.1:8000, [::1]:8000, localhost:1234. Or `demo` for demo
purpose."""
log_config_file: str | None = envs.VLLM_LOGGING_CONFIG_PATH
"""Path to logging config JSON file for both vllm and uvicorn"""
max_log_len: int | None = None
"""Max number of prompt characters or prompt ID numbers being printed in
log. The default of None means unlimited."""
disable_fastapi_docs: bool = False
"""Disable FastAPI's OpenAPI schema, Swagger UI, and ReDoc endpoint."""
enable_prompt_tokens_details: bool = False
"""If set to True, enable prompt_tokens_details in usage."""
enable_server_load_tracking: bool = False
"""If set to True, enable tracking server_load_metrics in the app state."""
enable_force_include_usage: bool = False
"""If set to True, including usage on every request."""
enable_tokenizer_info_endpoint: bool = False
"""Enable the /get_tokenizer_info endpoint. May expose chat
templates and other tokenizer configuration."""
enable_log_outputs: bool = False
"""If True, log model outputs (generations).
Requires --enable-log-requests."""
h11_max_incomplete_event_size: int = H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT
"""Maximum size (bytes) of an incomplete HTTP event (header or body) for
h11 parser. Helps mitigate header abuse. Default: 4194304 (4 MB)."""
h11_max_header_count: int = H11_MAX_HEADER_COUNT_DEFAULT
"""Maximum number of HTTP headers allowed in a request for h11 parser.
Helps mitigate header abuse. Default: 256."""
log_error_stack: bool = envs.VLLM_SERVER_DEV_MODE
"""If set to True, log the stack trace of error responses"""
tokens_only: bool = False
"""
If set to True, only enable the Tokens In<>Out endpoint.
This is intended for use in a Disaggregated Everything setup.
"""
@staticmethod
def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
from vllm.engine.arg_utils import get_kwargs
frontend_kwargs = get_kwargs(FrontendArgs)
# Special case: allowed_origins, allowed_methods, allowed_headers all
# need json.loads type
# Should also remove nargs
frontend_kwargs["allowed_origins"]["type"] = json.loads
frontend_kwargs["allowed_methods"]["type"] = json.loads
frontend_kwargs["allowed_headers"]["type"] = json.loads
del frontend_kwargs["allowed_origins"]["nargs"]
del frontend_kwargs["allowed_methods"]["nargs"]
del frontend_kwargs["allowed_headers"]["nargs"]
# Special case: LoRA modules need custom parser action and
# optional_type(str)
frontend_kwargs["lora_modules"]["type"] = optional_type(str)
frontend_kwargs["lora_modules"]["action"] = LoRAParserAction
# Special case: Middleware needs to append action
frontend_kwargs["middleware"]["action"] = "append"
frontend_kwargs["middleware"]["type"] = str
if "nargs" in frontend_kwargs["middleware"]:
del frontend_kwargs["middleware"]["nargs"]
frontend_kwargs["middleware"]["default"] = []
# Special case: Tool call parser shows built-in options.
valid_tool_parsers = list(ToolParserManager.list_registered())
parsers_str = ",".join(valid_tool_parsers)
frontend_kwargs["tool_call_parser"]["metavar"] = (
f"{{{parsers_str}}} or name registered in --tool-parser-plugin"
)
frontend_group = parser.add_argument_group(
title="Frontend",
description=FrontendArgs.__doc__,
)
for key, value in frontend_kwargs.items():
frontend_group.add_argument(f"--{key.replace('_', '-')}", **value)
return parser
def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"""Create the CLI argument parser used by the OpenAI API server.
We rely on the helper methods of `FrontendArgs` and `AsyncEngineArgs` to
register all arguments instead of manually enumerating them here. This
avoids code duplication and keeps the argument definitions in one place.
"""
parser.add_argument(
"model_tag",
type=str,
nargs="?",
help="The model tag to serve (optional if specified in config)",
)
parser.add_argument(
"--headless",
action="store_true",
default=False,
help="Run in headless mode. See multi-node data parallel "
"documentation for more details.",
)
parser.add_argument(
"--api-server-count",
"-asc",
type=int,
default=1,
help="How many API server processes to run.",
)
parser.add_argument(
"--config",
help="Read CLI options from a config file. "
"Must be a YAML with the following options: "
"https://docs.vllm.ai/en/latest/configuration/serve_args.html",
)
parser = FrontendArgs.add_cli_args(parser)
parser = AsyncEngineArgs.add_cli_args(parser)
return parser
def validate_parsed_serve_args(args: argparse.Namespace):
"""Quick checks for model serve args that raise prior to loading."""
if hasattr(args, "subparser") and args.subparser != "serve":
return
# Ensure that the chat template is valid; raises if it likely isn't
validate_chat_template(args.chat_template)
# Enable auto tool needs a tool call parser to be valid
if args.enable_auto_tool_choice and not args.tool_call_parser:
raise TypeError("Error: --enable-auto-tool-choice requires --tool-call-parser")
if args.enable_log_outputs and not args.enable_log_requests:
raise TypeError("Error: --enable-log-outputs requires --enable-log-requests")
def create_parser_for_docs() -> FlexibleArgumentParser:
parser_for_docs = FlexibleArgumentParser(
prog="-m vllm.entrypoints.openai.api_server"
)
return make_arg_parser(parser_for_docs)

View File

@@ -0,0 +1,120 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Utility functions that create ORCA endpoint load report response headers.
"""
import json
from collections.abc import Mapping
from vllm.logger import init_logger
from vllm.v1.metrics.reader import Gauge, get_metrics_snapshot
logger = init_logger(__name__)
def create_orca_header(
metrics_format: str, named_metrics: list[tuple[str, float]]
) -> Mapping[str, str] | None:
"""
Creates ORCA headers named 'endpoint-load-metrics' in the specified format
and adds custom metrics to named_metrics.
ORCA headers format description: https://docs.google.com/document/d/1C1ybMmDKJIVlrbOLbywhu9iRYo4rilR-cT50OTtOFTs/edit?tab=t.0
ORCA proto https://github.com/cncf/xds/blob/main/xds/data/orca/v3/orca_load_report.proto
Parameters:
- metrics_format (str): The format of the header ('TEXT', 'JSON').
- named_metrics (List[Tuple[str, float]]): List of tuples with metric names
and their corresponding double values.
Returns:
- Optional[Mapping[str,str]]: A dictionary with header key as
'endpoint-load-metrics' and values as the ORCA header strings with
format prefix and data in with named_metrics in.
"""
if metrics_format.lower() not in ["text", "json"]:
logger.warning(
"Warning: `%s` format is not supported in the ORCA response header",
format,
)
return None
header = {}
orca_report = {
"named_metrics": {
metric_name: value
for metric_name, value in named_metrics
if isinstance(metric_name, str) and isinstance(value, float)
}
}
# output example:
# endpoint-load-metrics: TEXT named_metrics.kv_cache_utilization=0.4
if metrics_format.lower() == "text":
native_http_header = ", ".join(
[
f"named_metrics.{metric_name}={value}"
for metric_name, value in named_metrics
if isinstance(metric_name, str) and isinstance(value, float)
]
)
header["endpoint-load-metrics"] = f"TEXT {native_http_header}"
# output example:
# endpoint-load-metrics: JSON “named_metrics”: {“custom-metric-util”: 0.4}
elif metrics_format.lower() == "json":
header["endpoint-load-metrics"] = f"JSON {json.dumps(orca_report)}"
logger.info("Created ORCA header %s", header)
return header
def get_named_metrics_from_prometheus() -> list[tuple[str, float]]:
"""
Collects current metrics from Prometheus and returns some of them
in the form of the `named_metrics` list for `create_orca_header()`.
Parameters:
- None
Returns:
- list[tuple[str, float]]: List of tuples of metric names and their values.
"""
named_metrics: list[tuple[str, float]] = []
# Map from prometheus metric names to ORCA named metrics.
prometheus_to_orca_metrics = {
"vllm:kv_cache_usage_perc": "kv_cache_usage_perc",
"vllm:num_requests_waiting": "num_requests_waiting",
}
metrics = get_metrics_snapshot()
for metric in metrics:
orca_name = prometheus_to_orca_metrics.get(metric.name)
# If this metric is mapped into ORCA, then add it to the report.
# Note: Only Gauge metrics are currently supported.
if orca_name is not None and isinstance(metric, Gauge):
named_metrics.append((str(orca_name), float(metric.value)))
return named_metrics
def metrics_header(metrics_format: str) -> Mapping[str, str] | None:
"""
Creates ORCA headers named 'endpoint-load-metrics' in the specified format.
Metrics are collected from Prometheus using `get_named_metrics_from_prometheus()`.
ORCA headers format description: https://docs.google.com/document/d/1C1ybMmDKJIVlrbOLbywhu9iRYo4rilR-cT50OTtOFTs/edit?tab=t.0
ORCA proto https://github.com/cncf/xds/blob/main/xds/data/orca/v3/orca_load_report.proto
Parameters:
- metrics_format (str): The format of the header ('TEXT', 'JSON').
Returns:
- Optional[Mapping[str,str]]: A dictionary with header key as
'endpoint-load-metrics' and values as the ORCA header strings with
format prefix and data in with named_metrics in.
"""
if not metrics_format:
return None
# Get named metrics from prometheus.
named_metrics = get_named_metrics_from_prometheus()
return create_orca_header(metrics_format, named_metrics)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,547 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import tempfile
from argparse import Namespace
from collections.abc import Awaitable, Callable
from http import HTTPStatus
from io import StringIO
import aiohttp
import torch
from prometheus_client import start_http_server
from tqdm import tqdm
from vllm.engine.arg_utils import AsyncEngineArgs, optional_type
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (
BatchRequestInput,
BatchRequestOutput,
BatchResponseData,
ChatCompletionResponse,
EmbeddingResponse,
ErrorResponse,
RerankResponse,
ScoreResponse,
)
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
from vllm.entrypoints.openai.serving_score import ServingScores
from vllm.logger import init_logger
from vllm.reasoning import ReasoningParserManager
from vllm.utils import random_uuid
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__)
def make_arg_parser(parser: FlexibleArgumentParser):
parser.add_argument(
"-i",
"--input-file",
required=True,
type=str,
help="The path or url to a single input file. Currently supports local file "
"paths, or the http protocol (http or https). If a URL is specified, "
"the file should be available via HTTP GET.",
)
parser.add_argument(
"-o",
"--output-file",
required=True,
type=str,
help="The path or url to a single output file. Currently supports "
"local file paths, or web (http or https) urls. If a URL is specified,"
" the file should be available via HTTP PUT.",
)
parser.add_argument(
"--output-tmp-dir",
type=str,
default=None,
help="The directory to store the output file before uploading it "
"to the output URL.",
)
parser.add_argument(
"--response-role",
type=optional_type(str),
default="assistant",
help="The role name to return if `request.add_generation_prompt=True`.",
)
parser = AsyncEngineArgs.add_cli_args(parser)
parser.add_argument(
"--max-log-len",
type=int,
default=None,
help="Max number of prompt characters or prompt "
"ID numbers being printed in log."
"\n\nDefault: Unlimited",
)
parser.add_argument(
"--enable-metrics", action="store_true", help="Enable Prometheus metrics"
)
parser.add_argument(
"--url",
type=str,
default="0.0.0.0",
help="URL to the Prometheus metrics server "
"(only needed if enable-metrics is set).",
)
parser.add_argument(
"--port",
type=int,
default=8000,
help="Port number for the Prometheus metrics server "
"(only needed if enable-metrics is set).",
)
parser.add_argument(
"--enable-prompt-tokens-details",
action="store_true",
default=False,
help="If set to True, enable prompt_tokens_details in usage.",
)
parser.add_argument(
"--enable-force-include-usage",
action="store_true",
default=False,
help="If set to True, include usage on every request "
"(even when stream_options is not specified)",
)
return parser
def parse_args():
parser = FlexibleArgumentParser(description="vLLM OpenAI-Compatible batch runner.")
return make_arg_parser(parser).parse_args()
# explicitly use pure text format, with a newline at the end
# this makes it impossible to see the animation in the progress bar
# but will avoid messing up with ray or multiprocessing, which wraps
# each line of output with some prefix.
_BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n" # noqa: E501
class BatchProgressTracker:
def __init__(self):
self._total = 0
self._pbar: tqdm | None = None
def submitted(self):
self._total += 1
def completed(self):
if self._pbar:
self._pbar.update()
def pbar(self) -> tqdm:
enable_tqdm = (
not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
)
self._pbar = tqdm(
total=self._total,
unit="req",
desc="Running batch",
mininterval=5,
disable=not enable_tqdm,
bar_format=_BAR_FORMAT,
)
return self._pbar
async def read_file(path_or_url: str) -> str:
if path_or_url.startswith("http://") or path_or_url.startswith("https://"):
async with aiohttp.ClientSession() as session, session.get(path_or_url) as resp:
return await resp.text()
else:
with open(path_or_url, encoding="utf-8") as f:
return f.read()
async def write_local_file(
output_path: str, batch_outputs: list[BatchRequestOutput]
) -> None:
"""
Write the responses to a local file.
output_path: The path to write the responses to.
batch_outputs: The list of batch outputs to write.
"""
# We should make this async, but as long as run_batch runs as a
# standalone program, blocking the event loop won't affect performance.
with open(output_path, "w", encoding="utf-8") as f:
for o in batch_outputs:
print(o.model_dump_json(), file=f)
async def upload_data(output_url: str, data_or_file: str, from_file: bool) -> None:
"""
Upload a local file to a URL.
output_url: The URL to upload the file to.
data_or_file: Either the data to upload or the path to the file to upload.
from_file: If True, data_or_file is the path to the file to upload.
"""
# Timeout is a common issue when uploading large files.
# We retry max_retries times before giving up.
max_retries = 5
# Number of seconds to wait before retrying.
delay = 5
for attempt in range(1, max_retries + 1):
try:
# We increase the timeout to 1000 seconds to allow
# for large files (default is 300).
async with aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(total=1000)
) as session:
if from_file:
with open(data_or_file, "rb") as file:
async with session.put(output_url, data=file) as response:
if response.status != 200:
raise Exception(
f"Failed to upload file.\n"
f"Status: {response.status}\n"
f"Response: {response.text()}"
)
else:
async with session.put(output_url, data=data_or_file) as response:
if response.status != 200:
raise Exception(
f"Failed to upload data.\n"
f"Status: {response.status}\n"
f"Response: {response.text()}"
)
except Exception as e:
if attempt < max_retries:
logger.error(
"Failed to upload data (attempt %d). Error message: %s.\nRetrying in %d seconds...", # noqa: E501
attempt,
e,
delay,
)
await asyncio.sleep(delay)
else:
raise Exception(
f"Failed to upload data (attempt {attempt}). Error message: {str(e)}." # noqa: E501
) from e
async def write_file(
path_or_url: str, batch_outputs: list[BatchRequestOutput], output_tmp_dir: str
) -> None:
"""
Write batch_outputs to a file or upload to a URL.
path_or_url: The path or URL to write batch_outputs to.
batch_outputs: The list of batch outputs to write.
output_tmp_dir: The directory to store the output file before uploading it
to the output URL.
"""
if path_or_url.startswith("http://") or path_or_url.startswith("https://"):
if output_tmp_dir is None:
logger.info("Writing outputs to memory buffer")
output_buffer = StringIO()
for o in batch_outputs:
print(o.model_dump_json(), file=output_buffer)
output_buffer.seek(0)
logger.info("Uploading outputs to %s", path_or_url)
await upload_data(
path_or_url,
output_buffer.read().strip().encode("utf-8"),
from_file=False,
)
else:
# Write responses to a temporary file and then upload it to the URL.
with tempfile.NamedTemporaryFile(
mode="w",
encoding="utf-8",
dir=output_tmp_dir,
prefix="tmp_batch_output_",
suffix=".jsonl",
) as f:
logger.info("Writing outputs to temporary local file %s", f.name)
await write_local_file(f.name, batch_outputs)
logger.info("Uploading outputs to %s", path_or_url)
await upload_data(path_or_url, f.name, from_file=True)
else:
logger.info("Writing outputs to local file %s", path_or_url)
await write_local_file(path_or_url, batch_outputs)
def make_error_request_output(
request: BatchRequestInput, error_msg: str
) -> BatchRequestOutput:
batch_output = BatchRequestOutput(
id=f"vllm-{random_uuid()}",
custom_id=request.custom_id,
response=BatchResponseData(
status_code=HTTPStatus.BAD_REQUEST,
request_id=f"vllm-batch-{random_uuid()}",
),
error=error_msg,
)
return batch_output
async def make_async_error_request_output(
request: BatchRequestInput, error_msg: str
) -> BatchRequestOutput:
return make_error_request_output(request, error_msg)
async def run_request(
serving_engine_func: Callable,
request: BatchRequestInput,
tracker: BatchProgressTracker,
) -> BatchRequestOutput:
response = await serving_engine_func(request.body)
if isinstance(
response,
(ChatCompletionResponse, EmbeddingResponse, ScoreResponse, RerankResponse),
):
batch_output = BatchRequestOutput(
id=f"vllm-{random_uuid()}",
custom_id=request.custom_id,
response=BatchResponseData(
body=response, request_id=f"vllm-batch-{random_uuid()}"
),
error=None,
)
elif isinstance(response, ErrorResponse):
batch_output = BatchRequestOutput(
id=f"vllm-{random_uuid()}",
custom_id=request.custom_id,
response=BatchResponseData(
status_code=response.error.code,
request_id=f"vllm-batch-{random_uuid()}",
),
error=response,
)
else:
batch_output = make_error_request_output(
request, error_msg="Request must not be sent in stream mode"
)
tracker.completed()
return batch_output
def validate_run_batch_args(args):
valid_reasoning_parsers = ReasoningParserManager.list_registered()
if (
reasoning_parser := args.structured_outputs_config.reasoning_parser
) and reasoning_parser not in valid_reasoning_parsers:
raise KeyError(
f"invalid reasoning parser: {reasoning_parser} "
f"(chose from {{ {','.join(valid_reasoning_parsers)} }})"
)
async def run_batch(
engine_client: EngineClient,
args: Namespace,
) -> None:
if args.served_model_name is not None:
served_model_names = args.served_model_name
else:
served_model_names = [args.model]
if args.enable_log_requests:
request_logger = RequestLogger(max_log_len=args.max_log_len)
else:
request_logger = None
base_model_paths = [
BaseModelPath(name=name, model_path=args.model) for name in served_model_names
]
model_config = engine_client.model_config
supported_tasks = await engine_client.get_supported_tasks()
logger.info("Supported tasks: %s", supported_tasks)
# Create the openai serving objects.
openai_serving_models = OpenAIServingModels(
engine_client=engine_client,
base_model_paths=base_model_paths,
lora_modules=None,
)
openai_serving_chat = (
OpenAIServingChat(
engine_client,
openai_serving_models,
args.response_role,
request_logger=request_logger,
chat_template=None,
chat_template_content_format="auto",
reasoning_parser=args.structured_outputs_config.reasoning_parser,
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
enable_force_include_usage=args.enable_force_include_usage,
)
if "generate" in supported_tasks
else None
)
openai_serving_embedding = (
OpenAIServingEmbedding(
engine_client,
openai_serving_models,
request_logger=request_logger,
chat_template=None,
chat_template_content_format="auto",
)
if "embed" in supported_tasks
else None
)
enable_serving_reranking = (
"classify" in supported_tasks
and getattr(model_config.hf_config, "num_labels", 0) == 1
)
openai_serving_scores = (
ServingScores(
engine_client,
openai_serving_models,
request_logger=request_logger,
)
if ("embed" in supported_tasks or enable_serving_reranking)
else None
)
tracker = BatchProgressTracker()
logger.info("Reading batch from %s...", args.input_file)
# Submit all requests in the file to the engine "concurrently".
response_futures: list[Awaitable[BatchRequestOutput]] = []
for request_json in (await read_file(args.input_file)).strip().split("\n"):
# Skip empty lines.
request_json = request_json.strip()
if not request_json:
continue
request = BatchRequestInput.model_validate_json(request_json)
# Determine the type of request and run it.
if request.url == "/v1/chat/completions":
chat_handler_fn = (
openai_serving_chat.create_chat_completion
if openai_serving_chat is not None
else None
)
if chat_handler_fn is None:
response_futures.append(
make_async_error_request_output(
request,
error_msg="The model does not support Chat Completions API",
)
)
continue
response_futures.append(run_request(chat_handler_fn, request, tracker))
tracker.submitted()
elif request.url == "/v1/embeddings":
embed_handler_fn = (
openai_serving_embedding.create_embedding
if openai_serving_embedding is not None
else None
)
if embed_handler_fn is None:
response_futures.append(
make_async_error_request_output(
request,
error_msg="The model does not support Embeddings API",
)
)
continue
response_futures.append(run_request(embed_handler_fn, request, tracker))
tracker.submitted()
elif request.url.endswith("/score"):
score_handler_fn = (
openai_serving_scores.create_score
if openai_serving_scores is not None
else None
)
if score_handler_fn is None:
response_futures.append(
make_async_error_request_output(
request,
error_msg="The model does not support Scores API",
)
)
continue
response_futures.append(run_request(score_handler_fn, request, tracker))
tracker.submitted()
elif request.url.endswith("/rerank"):
rerank_handler_fn = (
openai_serving_scores.do_rerank
if openai_serving_scores is not None
else None
)
if rerank_handler_fn is None:
response_futures.append(
make_async_error_request_output(
request,
error_msg="The model does not support Rerank API",
)
)
continue
response_futures.append(run_request(rerank_handler_fn, request, tracker))
tracker.submitted()
else:
response_futures.append(
make_async_error_request_output(
request,
error_msg=f"URL {request.url} was used. "
"Supported endpoints: /v1/chat/completions, /v1/embeddings,"
" /score, /rerank ."
"See vllm/entrypoints/openai/api_server.py for supported "
"score/rerank versions.",
)
)
with tracker.pbar():
responses = await asyncio.gather(*response_futures)
await write_file(args.output_file, responses, args.output_tmp_dir)
async def main(args: Namespace):
from vllm.entrypoints.openai.api_server import build_async_engine_client
from vllm.usage.usage_lib import UsageContext
validate_run_batch_args(args)
async with build_async_engine_client(
args,
usage_context=UsageContext.OPENAI_BATCH_RUNNER,
disable_frontend_multiprocessing=False,
) as engine_client:
await run_batch(engine_client, args)
if __name__ == "__main__":
args = parse_args()
logger.info("vLLM batch processing API version %s", VLLM_VERSION)
logger.info("args: %s", args)
# Start the Prometheus metrics server. LLMEngine uses the Prometheus client
# to publish metrics at the /metrics endpoint.
if args.enable_metrics:
logger.info("Prometheus metrics enabled")
start_http_server(port=args.port, addr=args.url)
else:
logger.info("Prometheus metrics disabled")
asyncio.run(main(args))

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,235 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from http import HTTPStatus
from typing import cast
import jinja2
import numpy as np
from fastapi import Request
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
ClassificationChatRequest,
ClassificationCompletionRequest,
ClassificationData,
ClassificationRequest,
ClassificationResponse,
ErrorResponse,
UsageInfo,
)
from vllm.entrypoints.openai.serving_engine import (
ClassificationServeContext,
OpenAIServing,
ServeContext,
)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.renderer import RenderConfig
from vllm.logger import init_logger
from vllm.outputs import ClassificationOutput, PoolingRequestOutput
from vllm.pooling_params import PoolingParams
logger = init_logger(__name__)
class ClassificationMixin(OpenAIServing):
chat_template: str | None
chat_template_content_format: ChatTemplateContentFormatOption
trust_request_chat_template: bool
async def _preprocess(
self,
ctx: ServeContext,
) -> ErrorResponse | None:
"""
Process classification inputs: tokenize text, resolve adapters,
and prepare model-specific inputs.
"""
ctx = cast(ClassificationServeContext, ctx)
try:
ctx.tokenizer = await self.engine_client.get_tokenizer()
request_obj = ctx.request
if isinstance(request_obj, ClassificationChatRequest):
chat_request = request_obj
messages = chat_request.messages
trust_request_chat_template = getattr(
self,
"trust_request_chat_template",
False,
)
ret = self._validate_chat_template(
request_chat_template=chat_request.chat_template,
chat_template_kwargs=chat_request.chat_template_kwargs,
trust_request_chat_template=trust_request_chat_template,
)
if ret:
return ret
(
_,
_,
engine_prompts,
) = await self._preprocess_chat(
cast(ChatCompletionRequest, chat_request),
ctx.tokenizer,
messages,
chat_template=(
chat_request.chat_template
or getattr(self, "chat_template", None)
),
chat_template_content_format=cast(
ChatTemplateContentFormatOption,
getattr(self, "chat_template_content_format", "auto"),
),
add_generation_prompt=False,
continue_final_message=False,
add_special_tokens=chat_request.add_special_tokens,
)
ctx.engine_prompts = engine_prompts
elif isinstance(request_obj, ClassificationCompletionRequest):
completion_request = request_obj
input_data = completion_request.input
if input_data in (None, ""):
return self.create_error_response(
"Input or messages must be provided",
status_code=HTTPStatus.BAD_REQUEST,
)
if isinstance(input_data, list) and not input_data:
ctx.engine_prompts = []
return None
renderer = self._get_renderer(ctx.tokenizer)
prompt_input = cast(str | list[str], input_data)
ctx.engine_prompts = await renderer.render_prompt(
prompt_or_prompts=prompt_input,
config=self._build_render_config(completion_request),
)
else:
return self.create_error_response(
"Invalid classification request type",
status_code=HTTPStatus.BAD_REQUEST,
)
return None
except (ValueError, TypeError, jinja2.TemplateError) as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
def _build_response(
self,
ctx: ServeContext,
) -> ClassificationResponse | ErrorResponse:
"""
Convert model outputs to a formatted classification response
with probabilities and labels.
"""
ctx = cast(ClassificationServeContext, ctx)
items: list[ClassificationData] = []
num_prompt_tokens = 0
final_res_batch_checked = cast(list[PoolingRequestOutput], ctx.final_res_batch)
for idx, final_res in enumerate(final_res_batch_checked):
classify_res = ClassificationOutput.from_base(final_res.outputs)
probs = classify_res.probs
predicted_index = int(np.argmax(probs))
label = getattr(self.model_config.hf_config, "id2label", {}).get(
predicted_index
)
item = ClassificationData(
index=idx,
label=label,
probs=probs,
num_classes=len(probs),
)
items.append(item)
prompt_token_ids = final_res.prompt_token_ids
num_prompt_tokens += len(prompt_token_ids)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
total_tokens=num_prompt_tokens,
)
return ClassificationResponse(
id=ctx.request_id,
created=ctx.created_time,
model=ctx.model_name,
data=items,
usage=usage,
)
def _build_render_config(self, request: ClassificationRequest) -> RenderConfig:
return RenderConfig(
max_length=self.max_model_len,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
class ServingClassification(ClassificationMixin):
request_id_prefix = "classify"
def __init__(
self,
engine_client: EngineClient,
models: OpenAIServingModels,
*,
request_logger: RequestLogger | None,
chat_template: str | None = None,
chat_template_content_format: ChatTemplateContentFormatOption = "auto",
trust_request_chat_template: bool = False,
log_error_stack: bool = False,
) -> None:
super().__init__(
engine_client=engine_client,
models=models,
request_logger=request_logger,
log_error_stack=log_error_stack,
)
self.chat_template = chat_template
self.chat_template_content_format = chat_template_content_format
self.trust_request_chat_template = trust_request_chat_template
async def create_classify(
self,
request: ClassificationRequest,
raw_request: Request,
) -> ClassificationResponse | ErrorResponse:
model_name = self.models.model_name()
request_id = f"{self.request_id_prefix}-{self._base_request_id(raw_request)}"
ctx = ClassificationServeContext(
request=request,
raw_request=raw_request,
model_name=model_name,
request_id=request_id,
)
return await super().handle(ctx) # type: ignore
def _create_pooling_params(
self,
ctx: ClassificationServeContext,
) -> PoolingParams | ErrorResponse:
pooling_params = super()._create_pooling_params(ctx)
if isinstance(pooling_params, ErrorResponse):
return pooling_params
try:
pooling_params.verify("classify", self.model_config)
except ValueError as e:
return self.create_error_response(str(e))
return pooling_params

View File

@@ -0,0 +1,715 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import time
from collections.abc import AsyncGenerator, AsyncIterator
from collections.abc import Sequence as GenericSequence
from typing import cast
import jinja2
from fastapi import Request
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (
CompletionLogProbs,
CompletionRequest,
CompletionResponse,
CompletionResponseChoice,
CompletionResponseStreamChoice,
CompletionStreamResponse,
ErrorResponse,
PromptTokenUsageInfo,
RequestResponseMetadata,
UsageInfo,
)
from vllm.entrypoints.openai.serving_engine import OpenAIServing, clamp_prompt_logprobs
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.renderer import RenderConfig
from vllm.entrypoints.utils import get_max_tokens, should_include_usage
from vllm.inputs.data import EmbedsPrompt, TokensPrompt, is_embeds_prompt
from vllm.logger import init_logger
from vllm.logprobs import Logprob
from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils.async_utils import merge_async_iterators
from vllm.utils.collection_utils import as_list
from vllm.v1.sample.logits_processor import validate_logits_processors_parameters
logger = init_logger(__name__)
class OpenAIServingCompletion(OpenAIServing):
def __init__(
self,
engine_client: EngineClient,
models: OpenAIServingModels,
*,
request_logger: RequestLogger | None,
return_tokens_as_token_ids: bool = False,
enable_prompt_tokens_details: bool = False,
enable_force_include_usage: bool = False,
log_error_stack: bool = False,
):
super().__init__(
engine_client=engine_client,
models=models,
request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids,
log_error_stack=log_error_stack,
)
# set up logits processors
self.logits_processors = self.model_config.logits_processors
self.enable_prompt_tokens_details = enable_prompt_tokens_details
self.default_sampling_params = self.model_config.get_diff_sampling_param()
self.enable_force_include_usage = enable_force_include_usage
if self.default_sampling_params:
source = self.model_config.generation_config
source = "model" if source == "auto" else source
logger.info(
"Using default completion sampling params from %s: %s",
source,
self.default_sampling_params,
)
async def create_completion(
self,
request: CompletionRequest,
raw_request: Request | None = None,
) -> AsyncGenerator[str, None] | CompletionResponse | ErrorResponse:
"""Completion API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/completions/create
for the API specification. This API mimics the OpenAI Completion API.
NOTE: Currently we do not support the following feature:
- suffix (the language models we currently support do not support
suffix)
"""
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
return error_check_ret
# If the engine is dead, raise the engine's DEAD_ERROR.
# This is required for the streaming case, where we return a
# success status before we actually start generating text :).
if self.engine_client.errored:
raise self.engine_client.dead_error
# Return error for unsupported features.
if request.suffix is not None:
return self.create_error_response("suffix is not currently supported")
if request.echo and request.prompt_embeds is not None:
return self.create_error_response("Echo is unsupported with prompt embeds.")
if request.prompt_logprobs is not None and request.prompt_embeds is not None:
return self.create_error_response(
"prompt_logprobs is not compatible with prompt embeds."
)
request_id = f"cmpl-{self._base_request_id(raw_request, request.request_id)}"
created_time = int(time.time())
request_metadata = RequestResponseMetadata(request_id=request_id)
if raw_request:
raw_request.state.request_metadata = request_metadata
try:
lora_request = self._maybe_get_adapters(request)
if self.model_config.skip_tokenizer_init:
tokenizer = None
else:
tokenizer = await self.engine_client.get_tokenizer()
renderer = self._get_renderer(tokenizer)
engine_prompts = await renderer.render_prompt_and_embeds(
prompt_or_prompts=request.prompt,
prompt_embeds=request.prompt_embeds,
config=self._build_render_config(request),
)
except ValueError as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
except TypeError as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
except RuntimeError as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
except jinja2.TemplateError as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
# Extract data_parallel_rank from header (router can inject it)
data_parallel_rank = self._get_data_parallel_rank(raw_request)
# Schedule the request and get the result generator.
generators: list[AsyncGenerator[RequestOutput, None]] = []
try:
for i, engine_prompt in enumerate(engine_prompts):
prompt_text, prompt_token_ids, prompt_embeds = (
self._get_prompt_components(engine_prompt)
)
input_length = None
if prompt_token_ids is not None:
input_length = len(prompt_token_ids)
elif prompt_embeds is not None:
input_length = len(prompt_embeds)
else:
raise NotImplementedError
if self.default_sampling_params is None:
self.default_sampling_params = {}
max_tokens = get_max_tokens(
max_model_len=self.max_model_len,
request=request,
input_length=input_length,
default_sampling_params=self.default_sampling_params,
)
sampling_params: SamplingParams | BeamSearchParams
if request.use_beam_search:
sampling_params = request.to_beam_search_params(
max_tokens, self.default_sampling_params
)
else:
sampling_params = request.to_sampling_params(
max_tokens,
self.model_config.logits_processor_pattern,
self.default_sampling_params,
)
validate_logits_processors_parameters(
self.logits_processors,
sampling_params,
)
request_id_item = f"{request_id}-{i}"
self._log_inputs(
request_id_item,
engine_prompt,
params=sampling_params,
lora_request=lora_request,
)
trace_headers = (
None
if raw_request is None
else await self._get_trace_headers(raw_request.headers)
)
# Mypy inconsistently requires this second cast in different
# environments. It shouldn't be necessary (redundant from above)
# but pre-commit in CI fails without it.
engine_prompt = cast(EmbedsPrompt | TokensPrompt, engine_prompt)
if isinstance(sampling_params, BeamSearchParams):
generator = self.beam_search(
prompt=engine_prompt,
request_id=request_id,
params=sampling_params,
lora_request=lora_request,
)
else:
engine_request, tokenization_kwargs = await self._process_inputs(
request_id_item,
engine_prompt,
sampling_params,
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
)
generator = self.engine_client.generate(
engine_request,
sampling_params,
request_id_item,
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
prompt_text=prompt_text,
tokenization_kwargs=tokenization_kwargs,
data_parallel_rank=data_parallel_rank,
)
generators.append(generator)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
result_generator = merge_async_iterators(*generators)
model_name = self.models.model_name(lora_request)
num_prompts = len(engine_prompts)
# Similar to the OpenAI API, when n != best_of, we do not stream the
# results. Noting that best_of is only supported in V0. In addition,
# we do not stream the results when use beam search.
stream = (
request.stream
and (request.best_of is None or request.n == request.best_of)
and not request.use_beam_search
)
# Streaming response
if stream:
return self.completion_stream_generator(
request,
engine_prompts,
result_generator,
request_id,
created_time,
model_name,
num_prompts=num_prompts,
tokenizer=tokenizer,
request_metadata=request_metadata,
)
# Non-streaming response
final_res_batch: list[RequestOutput | None] = [None] * num_prompts
try:
async for i, res in result_generator:
final_res_batch[i] = res
for i, final_res in enumerate(final_res_batch):
assert final_res is not None
# The output should contain the input text
# We did not pass it into vLLM engine to avoid being redundant
# with the inputs token IDs
if final_res.prompt is None:
engine_prompt = engine_prompts[i]
final_res.prompt = (
None
if is_embeds_prompt(engine_prompt)
else engine_prompt.get("prompt")
)
final_res_batch_checked = cast(list[RequestOutput], final_res_batch)
response = self.request_output_to_completion_response(
final_res_batch_checked,
request,
request_id,
created_time,
model_name,
tokenizer,
request_metadata,
)
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
# When user requests streaming but we don't stream, we still need to
# return a streaming response with a single event.
if request.stream:
response_json = response.model_dump_json()
async def fake_stream_generator() -> AsyncGenerator[str, None]:
yield f"data: {response_json}\n\n"
yield "data: [DONE]\n\n"
return fake_stream_generator()
return response
async def completion_stream_generator(
self,
request: CompletionRequest,
engine_prompts: list[TokensPrompt | EmbedsPrompt],
result_generator: AsyncIterator[tuple[int, RequestOutput]],
request_id: str,
created_time: int,
model_name: str,
num_prompts: int,
tokenizer: AnyTokenizer,
request_metadata: RequestResponseMetadata,
) -> AsyncGenerator[str, None]:
num_choices = 1 if request.n is None else request.n
previous_text_lens = [0] * num_choices * num_prompts
previous_num_tokens = [0] * num_choices * num_prompts
has_echoed = [False] * num_choices * num_prompts
num_prompt_tokens = [0] * num_prompts
num_cached_tokens = None
first_iteration = True
stream_options = request.stream_options
include_usage, include_continuous_usage = should_include_usage(
stream_options, self.enable_force_include_usage
)
try:
async for prompt_idx, res in result_generator:
prompt_token_ids = res.prompt_token_ids
prompt_logprobs = res.prompt_logprobs
if first_iteration:
num_cached_tokens = res.num_cached_tokens
first_iteration = False
prompt_text = res.prompt
if prompt_text is None:
engine_prompt = engine_prompts[prompt_idx]
prompt_text = (
None
if is_embeds_prompt(engine_prompt)
else engine_prompt.get("prompt")
)
# Prompt details are excluded from later streamed outputs
if prompt_token_ids is not None:
num_prompt_tokens[prompt_idx] = len(prompt_token_ids)
delta_token_ids: GenericSequence[int]
out_logprobs: GenericSequence[dict[int, Logprob] | None] | None
for output in res.outputs:
i = output.index + prompt_idx * num_choices
# Useful when request.return_token_ids is True
# Returning prompt token IDs shares the same logic
# with the echo implementation.
prompt_token_ids_to_return: list[int] | None = None
assert request.max_tokens is not None
if request.echo and not has_echoed[i]:
assert prompt_token_ids is not None
if request.return_token_ids:
prompt_text = ""
assert prompt_text is not None
if request.max_tokens == 0:
# only return the prompt
delta_text = prompt_text
delta_token_ids = prompt_token_ids
out_logprobs = prompt_logprobs
else:
# echo the prompt and first token
delta_text = prompt_text + output.text
delta_token_ids = [
*prompt_token_ids,
*output.token_ids,
]
out_logprobs = [
*(prompt_logprobs or []),
*(output.logprobs or []),
]
prompt_token_ids_to_return = prompt_token_ids
has_echoed[i] = True
else:
# return just the delta
delta_text = output.text
delta_token_ids = output.token_ids
out_logprobs = output.logprobs
# has_echoed[i] is reused here to indicate whether
# we have already returned the prompt token IDs.
if not has_echoed[i] and request.return_token_ids:
prompt_token_ids_to_return = prompt_token_ids
has_echoed[i] = True
if (
not delta_text
and not delta_token_ids
and not previous_num_tokens[i]
):
# Chunked prefill case, don't return empty chunks
continue
if request.logprobs is not None:
assert out_logprobs is not None, "Did not output logprobs"
logprobs = self._create_completion_logprobs(
token_ids=delta_token_ids,
top_logprobs=out_logprobs,
num_output_top_logprobs=request.logprobs,
tokenizer=tokenizer,
initial_text_offset=previous_text_lens[i],
return_as_token_id=request.return_tokens_as_token_ids,
)
else:
logprobs = None
previous_text_lens[i] += len(output.text)
previous_num_tokens[i] += len(output.token_ids)
finish_reason = output.finish_reason
stop_reason = output.stop_reason
chunk = CompletionStreamResponse(
id=request_id,
created=created_time,
model=model_name,
choices=[
CompletionResponseStreamChoice(
index=i,
text=delta_text,
logprobs=logprobs,
finish_reason=finish_reason,
stop_reason=stop_reason,
prompt_token_ids=prompt_token_ids_to_return,
token_ids=(
as_list(output.token_ids)
if request.return_token_ids
else None
),
)
],
)
if include_continuous_usage:
prompt_tokens = num_prompt_tokens[prompt_idx]
completion_tokens = previous_num_tokens[i]
chunk.usage = UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
response_json = chunk.model_dump_json(exclude_unset=False)
yield f"data: {response_json}\n\n"
total_prompt_tokens = sum(num_prompt_tokens)
total_completion_tokens = sum(previous_num_tokens)
final_usage_info = UsageInfo(
prompt_tokens=total_prompt_tokens,
completion_tokens=total_completion_tokens,
total_tokens=total_prompt_tokens + total_completion_tokens,
)
if self.enable_prompt_tokens_details and num_cached_tokens:
final_usage_info.prompt_tokens_details = PromptTokenUsageInfo(
cached_tokens=num_cached_tokens
)
if include_usage:
final_usage_chunk = CompletionStreamResponse(
id=request_id,
created=created_time,
model=model_name,
choices=[],
usage=final_usage_info,
)
final_usage_data = final_usage_chunk.model_dump_json(
exclude_unset=False, exclude_none=True
)
yield f"data: {final_usage_data}\n\n"
# report to FastAPI middleware aggregate usage across all choices
request_metadata.final_usage_info = final_usage_info
except Exception as e:
# TODO: Use a vllm-specific Validation Error
data = self.create_streaming_error_response(str(e))
yield f"data: {data}\n\n"
yield "data: [DONE]\n\n"
def request_output_to_completion_response(
self,
final_res_batch: list[RequestOutput],
request: CompletionRequest,
request_id: str,
created_time: int,
model_name: str,
tokenizer: AnyTokenizer,
request_metadata: RequestResponseMetadata,
) -> CompletionResponse:
choices: list[CompletionResponseChoice] = []
num_prompt_tokens = 0
num_generated_tokens = 0
kv_transfer_params = None
last_final_res = None
for final_res in final_res_batch:
last_final_res = final_res
prompt_token_ids = final_res.prompt_token_ids
assert prompt_token_ids is not None
prompt_logprobs = clamp_prompt_logprobs(final_res.prompt_logprobs)
prompt_text = final_res.prompt
token_ids: GenericSequence[int]
out_logprobs: GenericSequence[dict[int, Logprob] | None] | None
for output in final_res.outputs:
assert request.max_tokens is not None
if request.echo:
if request.return_token_ids:
prompt_text = ""
assert prompt_text is not None
if request.max_tokens == 0:
token_ids = prompt_token_ids
out_logprobs = prompt_logprobs
output_text = prompt_text
else:
token_ids = [*prompt_token_ids, *output.token_ids]
if request.logprobs is None:
out_logprobs = None
else:
assert prompt_logprobs is not None
assert output.logprobs is not None
out_logprobs = [
*prompt_logprobs,
*output.logprobs,
]
output_text = prompt_text + output.text
else:
token_ids = output.token_ids
out_logprobs = output.logprobs
output_text = output.text
if request.logprobs is not None:
assert out_logprobs is not None, "Did not output logprobs"
logprobs = self._create_completion_logprobs(
token_ids=token_ids,
top_logprobs=out_logprobs,
tokenizer=tokenizer,
num_output_top_logprobs=request.logprobs,
return_as_token_id=request.return_tokens_as_token_ids,
)
else:
logprobs = None
choice_data = CompletionResponseChoice(
index=len(choices),
text=output_text,
logprobs=logprobs,
finish_reason=output.finish_reason,
stop_reason=output.stop_reason,
prompt_logprobs=final_res.prompt_logprobs,
prompt_token_ids=(
prompt_token_ids if request.return_token_ids else None
),
token_ids=(
as_list(output.token_ids) if request.return_token_ids else None
),
)
choices.append(choice_data)
num_generated_tokens += len(output.token_ids)
num_prompt_tokens += len(prompt_token_ids)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens,
total_tokens=num_prompt_tokens + num_generated_tokens,
)
if (
self.enable_prompt_tokens_details
and last_final_res
and last_final_res.num_cached_tokens
):
usage.prompt_tokens_details = PromptTokenUsageInfo(
cached_tokens=last_final_res.num_cached_tokens
)
request_metadata.final_usage_info = usage
if final_res_batch:
kv_transfer_params = final_res_batch[0].kv_transfer_params
return CompletionResponse(
id=request_id,
created=created_time,
model=model_name,
choices=choices,
usage=usage,
kv_transfer_params=kv_transfer_params,
)
def _create_completion_logprobs(
self,
token_ids: GenericSequence[int],
top_logprobs: GenericSequence[dict[int, Logprob] | None],
num_output_top_logprobs: int,
tokenizer: AnyTokenizer,
initial_text_offset: int = 0,
return_as_token_id: bool | None = None,
) -> CompletionLogProbs:
"""Create logprobs for OpenAI Completion API."""
out_text_offset: list[int] = []
out_token_logprobs: list[float | None] = []
out_tokens: list[str] = []
out_top_logprobs: list[dict[str, float] | None] = []
last_token_len = 0
should_return_as_token_id = (
return_as_token_id
if return_as_token_id is not None
else self.return_tokens_as_token_ids
)
for i, token_id in enumerate(token_ids):
step_top_logprobs = top_logprobs[i]
if step_top_logprobs is None:
token = tokenizer.decode(token_id)
if should_return_as_token_id:
token = f"token_id:{token_id}"
out_tokens.append(token)
out_token_logprobs.append(None)
out_top_logprobs.append(None)
else:
step_token = step_top_logprobs[token_id]
token = self._get_decoded_token(
step_token,
token_id,
tokenizer,
return_as_token_id=should_return_as_token_id,
)
token_logprob = max(step_token.logprob, -9999.0)
out_tokens.append(token)
out_token_logprobs.append(token_logprob)
# makes sure to add the top num_output_top_logprobs + 1
# logprobs, as defined in the openai API
# (cf. https://github.com/openai/openai-openapi/blob/
# 893ba52242dbd5387a97b96444ee1c742cfce9bd/openapi.yaml#L7153)
out_top_logprobs.append(
{
# Convert float("-inf") to the
# JSON-serializable float that OpenAI uses
self._get_decoded_token(
top_lp[1],
top_lp[0],
tokenizer,
return_as_token_id=should_return_as_token_id,
): max(top_lp[1].logprob, -9999.0)
for i, top_lp in enumerate(step_top_logprobs.items())
if num_output_top_logprobs >= i
}
)
if len(out_text_offset) == 0:
out_text_offset.append(initial_text_offset)
else:
out_text_offset.append(out_text_offset[-1] + last_token_len)
last_token_len = len(token)
return CompletionLogProbs(
text_offset=out_text_offset,
token_logprobs=out_token_logprobs,
tokens=out_tokens,
top_logprobs=out_top_logprobs,
)
def _build_render_config(
self,
request: CompletionRequest,
max_input_length: int | None = None,
) -> RenderConfig:
max_input_tokens_len = self.max_model_len - (request.max_tokens or 0)
return RenderConfig(
max_length=max_input_tokens_len,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
cache_salt=request.cache_salt,
needs_detokenization=bool(request.echo and not request.return_token_ids),
)

View File

@@ -0,0 +1,695 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from collections.abc import AsyncGenerator, Mapping
from typing import Any, Final, cast
import torch
from fastapi import Request
from fastapi.responses import Response
from typing_extensions import assert_never, override
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (
EmbeddingBytesResponse,
EmbeddingChatRequest,
EmbeddingCompletionRequest,
EmbeddingRequest,
EmbeddingResponse,
EmbeddingResponseData,
ErrorResponse,
UsageInfo,
)
from vllm.entrypoints.openai.serving_engine import (
EmbeddingServeContext,
OpenAIServing,
ServeContext,
TextTokensPrompt,
)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.renderer import RenderConfig
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
from vllm.logger import init_logger
from vllm.outputs import (
EmbeddingRequestOutput,
PoolingOutput,
PoolingRequestOutput,
RequestOutput,
)
from vllm.pooling_params import PoolingParams
from vllm.utils.async_utils import merge_async_iterators
from vllm.utils.collection_utils import chunk_list
from vllm.utils.serial_utils import (
EmbedDType,
EncodingFormat,
Endianness,
encode_pooling_bytes,
encode_pooling_output,
)
logger = init_logger(__name__)
class EmbeddingMixin(OpenAIServing):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
pooler_config = self.model_config.pooler_config
# Avoid repeated attribute lookups
self.supports_chunked_processing = bool(
pooler_config and pooler_config.enable_chunked_processing
)
self.max_embed_len = (
pooler_config.max_embed_len
if pooler_config and pooler_config.max_embed_len
else None
)
@override
async def _preprocess(
self,
ctx: ServeContext,
) -> ErrorResponse | None:
ctx = cast(EmbeddingServeContext, ctx)
try:
ctx.lora_request = self._maybe_get_adapters(ctx.request)
tokenizer = await self.engine_client.get_tokenizer()
renderer = self._get_renderer(tokenizer)
if isinstance(ctx.request, EmbeddingChatRequest):
(
_,
_,
ctx.engine_prompts,
) = await self._preprocess_chat(
ctx.request,
tokenizer,
ctx.request.messages,
chat_template=ctx.request.chat_template or ctx.chat_template,
chat_template_content_format=ctx.chat_template_content_format,
add_generation_prompt=ctx.request.add_generation_prompt,
continue_final_message=False,
add_special_tokens=ctx.request.add_special_tokens,
)
else:
ctx.engine_prompts = await renderer.render_prompt(
prompt_or_prompts=ctx.request.input,
config=self._build_render_config(ctx.request),
)
return None
except (ValueError, TypeError) as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
def _build_render_config(self, request: EmbeddingCompletionRequest) -> RenderConfig:
# Set max_length based on chunked processing capability
if self._should_use_chunked_processing(request):
max_length = None
else:
max_length = self.max_embed_len or self.max_model_len
return RenderConfig(
max_length=max_length,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
@override
def _build_response(
self,
ctx: ServeContext,
) -> EmbeddingResponse | Response | ErrorResponse:
final_res_batch_checked = cast(list[PoolingRequestOutput], ctx.final_res_batch)
encoding_format: EncodingFormat = ctx.request.encoding_format
embed_dtype: EmbedDType = ctx.request.embed_dtype
endianness: Endianness = ctx.request.endianness
def encode_float_base64():
items: list[EmbeddingResponseData] = []
num_prompt_tokens = 0
for idx, final_res in enumerate(final_res_batch_checked):
item = EmbeddingResponseData(
index=idx,
embedding=encode_pooling_output(
final_res,
encoding_format=encoding_format,
embed_dtype=embed_dtype,
endianness=endianness,
),
)
prompt_token_ids = final_res.prompt_token_ids
items.append(item)
num_prompt_tokens += len(prompt_token_ids)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
total_tokens=num_prompt_tokens,
)
return EmbeddingResponse(
id=ctx.request_id,
created=ctx.created_time,
model=ctx.model_name,
data=items,
usage=usage,
)
def encode_bytes():
body, items, usage = encode_pooling_bytes(
pooling_outputs=final_res_batch_checked,
embed_dtype=embed_dtype,
endianness=endianness,
)
metadata = {
"id": ctx.request_id,
"created": ctx.created_time,
"model": ctx.model_name,
"data": items,
"usage": usage,
}
return EmbeddingBytesResponse(
body=body,
metadata=json.dumps(metadata),
)
if encoding_format == "float" or encoding_format == "base64":
return encode_float_base64()
elif encoding_format == "bytes":
return encode_bytes()
else:
assert_never(encoding_format)
def _get_max_position_embeddings(self) -> int:
"""Get the model's effective maximum sequence length for chunking."""
return self.model_config.max_model_len
def _should_use_chunked_processing(self, request) -> bool:
"""Check if chunked processing should be used for this request."""
return (
isinstance(request, (EmbeddingCompletionRequest, EmbeddingChatRequest))
and self.supports_chunked_processing
)
async def _process_chunked_request(
self,
ctx: EmbeddingServeContext,
original_prompt: TextTokensPrompt,
pooling_params,
trace_headers,
prompt_idx: int,
) -> list[AsyncGenerator[PoolingRequestOutput, None]]:
"""Process a single prompt using chunked processing."""
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
token_ids = original_prompt["prompt_token_ids"]
# Split into chunks using max_position_embeddings
max_pos_embeddings = self._get_max_position_embeddings()
# Process all chunks for MEAN aggregation
for chunk_idx, chunk_tokens in enumerate(
chunk_list(token_ids, max_pos_embeddings)
):
# Create a request ID for this chunk
chunk_request_id = f"{ctx.request_id}-prompt-{prompt_idx}-chunk-{chunk_idx}"
# Create engine prompt for this chunk
chunk_engine_prompt = EngineTokensPrompt(prompt_token_ids=chunk_tokens)
# Create chunk request prompt for logging
chunk_text = ""
chunk_request_prompt = TextTokensPrompt(
prompt=chunk_text, prompt_token_ids=chunk_tokens
)
# Log the chunk
self._log_inputs(
chunk_request_id,
chunk_request_prompt,
params=pooling_params,
lora_request=ctx.lora_request,
)
# Create generator for this chunk and wrap it to return indices
original_generator = self.engine_client.encode(
chunk_engine_prompt,
pooling_params,
chunk_request_id,
lora_request=ctx.lora_request,
trace_headers=trace_headers,
priority=getattr(ctx.request, "priority", 0),
)
generators.append(original_generator)
return generators
def _validate_input(
self,
request,
input_ids: list[int],
input_text: str,
) -> TextTokensPrompt:
"""Override to support chunked processing for embedding requests."""
token_num = len(input_ids)
# Note: EmbeddingRequest doesn't have max_tokens
if isinstance(request, (EmbeddingCompletionRequest, EmbeddingChatRequest)):
# Check if chunked processing is enabled for pooling models
enable_chunked = self._should_use_chunked_processing(request)
# Use max_position_embeddings for chunked processing decisions
max_pos_embeddings = self._get_max_position_embeddings()
# Determine the effective max length for validation
if self.max_embed_len is not None:
# Use max_embed_len for validation instead of max_model_len
length_type = "maximum embedding input length"
max_length_value = self.max_embed_len
else:
# Fall back to max_model_len validation (original behavior)
length_type = "maximum context length"
max_length_value = self.max_model_len
validation_error_msg = (
"This model's {length_type} is {max_length_value} tokens. "
"However, you requested {token_num} tokens in the input for "
"embedding generation. Please reduce the length of the input."
)
chunked_processing_error_msg = (
"This model's {length_type} is {max_length_value} tokens. "
"However, you requested {token_num} tokens in the input for "
"embedding generation. Please reduce the length of the input "
"or enable chunked processing."
)
# Check if input exceeds max length
if token_num > max_length_value:
raise ValueError(
validation_error_msg.format(
length_type=length_type,
max_length_value=max_length_value,
token_num=token_num,
)
)
# Check for chunked processing
# when exceeding max_position_embeddings
if token_num > max_pos_embeddings:
if enable_chunked:
# Allow long inputs when chunked processing is enabled
logger.info(
"Input length %s exceeds max_position_embeddings "
"%s, will use chunked processing",
token_num,
max_pos_embeddings,
)
else:
raise ValueError(
chunked_processing_error_msg.format(
length_type="maximum position embeddings length",
max_length_value=max_pos_embeddings,
token_num=token_num,
)
)
return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
# For other request types, use the parent's implementation
return super()._validate_input(request, input_ids, input_text)
def _is_text_tokens_prompt(self, prompt) -> bool:
"""Check if a prompt is a TextTokensPrompt (has prompt_token_ids)."""
return (
isinstance(prompt, dict)
and "prompt_token_ids" in prompt
and "prompt_embeds" not in prompt
)
async def _create_single_prompt_generator(
self,
ctx: EmbeddingServeContext,
engine_prompt: EngineTokensPrompt,
pooling_params: PoolingParams,
trace_headers: Mapping[str, str] | None,
prompt_index: int,
) -> AsyncGenerator[RequestOutput | PoolingRequestOutput, None]:
"""Create a generator for a single prompt using standard processing."""
request_id_item = f"{ctx.request_id}-{prompt_index}"
self._log_inputs(
request_id_item,
engine_prompt,
params=pooling_params,
lora_request=ctx.lora_request,
)
# Return the original generator without wrapping
return self.engine_client.encode(
engine_prompt,
pooling_params,
request_id_item,
lora_request=ctx.lora_request,
trace_headers=trace_headers,
priority=getattr(ctx.request, "priority", 0),
)
@override
async def _prepare_generators(
self,
ctx: ServeContext,
) -> ErrorResponse | None:
"""Override to support chunked processing."""
ctx = cast(EmbeddingServeContext, ctx)
# Check if we should use chunked processing
use_chunked = self._should_use_chunked_processing(ctx.request)
# If no chunked processing needed, delegate to parent class
if not use_chunked:
return await super()._prepare_generators(ctx)
# Custom logic for chunked processing
generators: list[
AsyncGenerator[RequestOutput | PoolingRequestOutput, None]
] = []
try:
trace_headers = (
None
if ctx.raw_request is None
else await self._get_trace_headers(ctx.raw_request.headers)
)
pooling_params = self._create_pooling_params(ctx)
if isinstance(pooling_params, ErrorResponse):
return pooling_params
# Verify and set the task for pooling params
try:
pooling_params.verify("embed", self.model_config)
except ValueError as e:
return self.create_error_response(str(e))
if ctx.engine_prompts is None:
return self.create_error_response("Engine prompts not available")
max_pos_embeddings = self._get_max_position_embeddings()
for i, engine_prompt in enumerate(ctx.engine_prompts):
# Check if this specific prompt needs chunked processing
if self._is_text_tokens_prompt(engine_prompt):
# Cast to TextTokensPrompt since we've verified
# prompt_token_ids
text_tokens_prompt = cast(TextTokensPrompt, engine_prompt)
if len(text_tokens_prompt["prompt_token_ids"]) > max_pos_embeddings:
# Use chunked processing for this prompt
chunk_generators = await self._process_chunked_request(
ctx, text_tokens_prompt, pooling_params, trace_headers, i
)
generators.extend(chunk_generators)
continue
# Normal processing for short prompts or non-token prompts
generator = await self._create_single_prompt_generator(
ctx, engine_prompt, pooling_params, trace_headers, i
)
generators.append(generator)
ctx.result_generator = merge_async_iterators(*generators)
return None
except Exception as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
@override
async def _collect_batch(
self,
ctx: ServeContext,
) -> ErrorResponse | None:
"""Collect and aggregate batch results
with support for chunked processing.
For chunked requests, performs online aggregation to
minimize memory usage.
For regular requests, collects results normally.
"""
ctx = cast(EmbeddingServeContext, ctx)
try:
if ctx.engine_prompts is None:
return self.create_error_response("Engine prompts not available")
# Check if we used chunked processing
use_chunked = self._should_use_chunked_processing(ctx.request)
if not use_chunked:
return await super()._collect_batch(ctx=ctx)
if ctx.result_generator is None:
return self.create_error_response("Result generator not available")
# Online aggregation for chunked requests to
# minimize memory usage
# Track aggregation state for each prompt
prompt_aggregators: dict[int, dict[str, Any]] = {}
short_prompts_results: dict[int, PoolingRequestOutput] = {}
async for result_idx, result in ctx.result_generator:
if "-chunk-" in result.request_id:
# Extract prompt_idx from chunked request_id
parts = result.request_id.split("-")
try:
prompt_idx = int(parts[parts.index("prompt") + 1])
except (ValueError, IndexError):
# Fallback: extract from result_idx if parsing fails
prompt_idx = result_idx
# Initialize aggregator for this prompt if needed
if prompt_idx not in prompt_aggregators:
prompt_aggregators[prompt_idx] = {
"weighted_sum": None,
"total_weight": 0,
"chunk_count": 0,
"request_id": result.request_id.split("-chunk-")[0],
}
aggregator = prompt_aggregators[prompt_idx]
# MEAN pooling with online weighted averaging
# Ensure result is PoolingRequestOutput
# for embedding processing
if not isinstance(result, PoolingRequestOutput):
return self.create_error_response(
f"Expected PoolingRequestOutput for "
f"chunked embedding, got "
f"{type(result).__name__}"
)
# Handle both PoolingOutput and
# EmbeddingOutput types
if hasattr(result.outputs, "data"):
# PoolingOutput case
embedding_data = result.outputs.data
elif hasattr(result.outputs, "embedding"):
# EmbeddingOutput case -
# convert embedding list to tensor
embedding_data = result.outputs.embedding
else:
return self.create_error_response(
f"Unsupported output type: {type(result.outputs).__name__}"
)
if not isinstance(embedding_data, torch.Tensor):
embedding_data = torch.tensor(
embedding_data, dtype=torch.float32
)
if result.prompt_token_ids is None:
return self.create_error_response(
"prompt_token_ids cannot be None for chunked processing"
)
weight = len(result.prompt_token_ids)
weighted_embedding = embedding_data.to(dtype=torch.float32) * weight
if aggregator["weighted_sum"] is None:
# First chunk
aggregator["weighted_sum"] = weighted_embedding
else:
# Accumulate
aggregator["weighted_sum"] += weighted_embedding
aggregator["total_weight"] += weight
aggregator["chunk_count"] += 1
else:
# Non-chunked result - extract prompt_idx from request_id
parts = result.request_id.split("-")
try:
# Last part should be prompt index
prompt_idx = int(parts[-1])
except (ValueError, IndexError):
prompt_idx = result_idx # Fallback to result_idx
short_prompts_results[prompt_idx] = cast(
PoolingRequestOutput, result
)
# Finalize aggregated results
final_res_batch: list[PoolingRequestOutput | EmbeddingRequestOutput] = []
num_prompts = len(ctx.engine_prompts)
for prompt_idx in range(num_prompts):
if prompt_idx in prompt_aggregators:
# Finalize MEAN aggregation for this chunked prompt
aggregator = prompt_aggregators[prompt_idx]
weighted_sum = aggregator["weighted_sum"]
total_weight = aggregator["total_weight"]
if (
weighted_sum is not None
and isinstance(weighted_sum, torch.Tensor)
and isinstance(total_weight, (int, float))
and total_weight > 0
):
# Compute final mean embedding
final_embedding = weighted_sum / total_weight
# Create a PoolingRequestOutput
# for the aggregated result
pooling_output_data = PoolingOutput(data=final_embedding)
# Get original prompt token IDs for this prompt
original_prompt = ctx.engine_prompts[prompt_idx]
if not self._is_text_tokens_prompt(original_prompt):
return self.create_error_response(
f"Chunked prompt {prompt_idx} is not a TextTokensPrompt"
)
original_token_ids = cast(TextTokensPrompt, original_prompt)[
"prompt_token_ids"
]
pooling_request_output = PoolingRequestOutput(
request_id=aggregator["request_id"],
prompt_token_ids=original_token_ids,
outputs=pooling_output_data,
num_cached_tokens=0,
finished=True,
)
final_res_batch.append(pooling_request_output)
else:
return self.create_error_response(
f"Failed to aggregate chunks for prompt {prompt_idx}"
)
elif prompt_idx in short_prompts_results:
final_res_batch.append(
cast(PoolingRequestOutput, short_prompts_results[prompt_idx])
)
else:
return self.create_error_response(
f"Result not found for prompt {prompt_idx}"
)
ctx.final_res_batch = cast(
list[RequestOutput | PoolingRequestOutput], final_res_batch
)
return None
except Exception as e:
return self.create_error_response(str(e))
class OpenAIServingEmbedding(EmbeddingMixin):
request_id_prefix = "embd"
def __init__(
self,
engine_client: EngineClient,
models: OpenAIServingModels,
*,
request_logger: RequestLogger | None,
chat_template: str | None,
chat_template_content_format: ChatTemplateContentFormatOption,
trust_request_chat_template: bool = False,
log_error_stack: bool = False,
) -> None:
super().__init__(
engine_client=engine_client,
models=models,
request_logger=request_logger,
log_error_stack=log_error_stack,
)
self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format
self.trust_request_chat_template = trust_request_chat_template
async def create_embedding(
self,
request: EmbeddingRequest,
raw_request: Request | None = None,
) -> EmbeddingResponse | ErrorResponse:
"""
Embedding API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/embeddings/create
for the API specification. This API mimics the OpenAI Embedding API.
"""
model_name = self.models.model_name()
request_id = (
f"{self.request_id_prefix}-"
f"{self._base_request_id(raw_request, request.request_id)}"
)
ctx = EmbeddingServeContext(
request=request,
raw_request=raw_request,
model_name=model_name,
request_id=request_id,
chat_template=self.chat_template,
chat_template_content_format=self.chat_template_content_format,
)
return await super().handle(ctx) # type: ignore
@override
def _create_pooling_params(
self,
ctx: ServeContext[EmbeddingRequest],
) -> PoolingParams | ErrorResponse:
pooling_params = super()._create_pooling_params(ctx)
if isinstance(pooling_params, ErrorResponse):
return pooling_params
try:
pooling_params.verify("embed", self.model_config)
except ValueError as e:
return self.create_error_response(str(e))
return pooling_params
async def _preprocess(
self,
ctx: ServeContext,
) -> ErrorResponse | None:
if isinstance(ctx.request, EmbeddingChatRequest):
error_check_ret = self._validate_chat_template(
request_chat_template=ctx.request.chat_template,
chat_template_kwargs=ctx.request.chat_template_kwargs,
trust_request_chat_template=self.trust_request_chat_template,
)
if error_check_ret is not None:
return error_check_ret
return await super()._preprocess(ctx)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,304 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from asyncio import Lock
from collections import defaultdict
from dataclasses import dataclass
from http import HTTPStatus
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.openai.protocol import (
ErrorInfo,
ErrorResponse,
LoadLoRAAdapterRequest,
ModelCard,
ModelList,
ModelPermission,
UnloadLoRAAdapterRequest,
)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry
from vllm.utils.counter import AtomicCounter
logger = init_logger(__name__)
@dataclass
class BaseModelPath:
name: str
model_path: str
@dataclass
class LoRAModulePath:
name: str
path: str
base_model_name: str | None = None
class OpenAIServingModels:
"""Shared instance to hold data about the loaded base model(s) and adapters.
Handles the routes:
- /v1/models
- /v1/load_lora_adapter
- /v1/unload_lora_adapter
"""
def __init__(
self,
engine_client: EngineClient,
base_model_paths: list[BaseModelPath],
*,
lora_modules: list[LoRAModulePath] | None = None,
):
super().__init__()
self.engine_client = engine_client
self.base_model_paths = base_model_paths
self.static_lora_modules = lora_modules
self.lora_requests: dict[str, LoRARequest] = {}
self.lora_id_counter = AtomicCounter(0)
self.lora_resolvers: list[LoRAResolver] = []
for lora_resolver_name in LoRAResolverRegistry.get_supported_resolvers():
self.lora_resolvers.append(
LoRAResolverRegistry.get_resolver(lora_resolver_name)
)
self.lora_resolver_lock: dict[str, Lock] = defaultdict(Lock)
self.processor = self.engine_client.processor
self.io_processor = self.engine_client.io_processor
self.model_config = self.engine_client.model_config
self.max_model_len = self.model_config.max_model_len
async def init_static_loras(self):
"""Loads all static LoRA modules.
Raises if any fail to load"""
if self.static_lora_modules is None:
return
for lora in self.static_lora_modules:
load_request = LoadLoRAAdapterRequest(
lora_path=lora.path, lora_name=lora.name
)
load_result = await self.load_lora_adapter(
request=load_request, base_model_name=lora.base_model_name
)
if isinstance(load_result, ErrorResponse):
raise ValueError(load_result.error.message)
def is_base_model(self, model_name) -> bool:
return any(model.name == model_name for model in self.base_model_paths)
def model_name(self, lora_request: LoRARequest | None = None) -> str:
"""Returns the appropriate model name depending on the availability
and support of the LoRA or base model.
Parameters:
- lora: LoRARequest that contain a base_model_name.
Returns:
- str: The name of the base model or the first available model path.
"""
if lora_request is not None:
return lora_request.lora_name
return self.base_model_paths[0].name
async def show_available_models(self) -> ModelList:
"""Show available models. This includes the base model and all
adapters"""
model_cards = [
ModelCard(
id=base_model.name,
max_model_len=self.max_model_len,
root=base_model.model_path,
permission=[ModelPermission()],
)
for base_model in self.base_model_paths
]
lora_cards = [
ModelCard(
id=lora.lora_name,
root=lora.local_path,
parent=lora.base_model_name
if lora.base_model_name
else self.base_model_paths[0].name,
permission=[ModelPermission()],
)
for lora in self.lora_requests.values()
]
model_cards.extend(lora_cards)
return ModelList(data=model_cards)
async def load_lora_adapter(
self, request: LoadLoRAAdapterRequest, base_model_name: str | None = None
) -> ErrorResponse | str:
lora_name = request.lora_name
# Ensure atomicity based on the lora name
async with self.lora_resolver_lock[lora_name]:
error_check_ret = await self._check_load_lora_adapter_request(request)
if error_check_ret is not None:
return error_check_ret
lora_path = request.lora_path
unique_id = self.lora_id_counter.inc(1)
lora_request = LoRARequest(
lora_name=lora_name, lora_int_id=unique_id, lora_path=lora_path
)
if base_model_name is not None and self.is_base_model(base_model_name):
lora_request.base_model_name = base_model_name
# Validate that the adapter can be loaded into the engine
# This will also pre-load it for incoming requests
try:
await self.engine_client.add_lora(lora_request)
except Exception as e:
error_type = "BadRequestError"
status_code = HTTPStatus.BAD_REQUEST
if "No adapter found" in str(e):
error_type = "NotFoundError"
status_code = HTTPStatus.NOT_FOUND
return create_error_response(
message=str(e), err_type=error_type, status_code=status_code
)
self.lora_requests[lora_name] = lora_request
logger.info(
"Loaded new LoRA adapter: name '%s', path '%s'", lora_name, lora_path
)
return f"Success: LoRA adapter '{lora_name}' added successfully."
async def unload_lora_adapter(
self, request: UnloadLoRAAdapterRequest
) -> ErrorResponse | str:
lora_name = request.lora_name
# Ensure atomicity based on the lora name
async with self.lora_resolver_lock[lora_name]:
error_check_ret = await self._check_unload_lora_adapter_request(request)
if error_check_ret is not None:
return error_check_ret
# Safe to delete now since we hold the lock
del self.lora_requests[lora_name]
logger.info("Removed LoRA adapter: name '%s'", lora_name)
return f"Success: LoRA adapter '{lora_name}' removed successfully."
async def _check_load_lora_adapter_request(
self, request: LoadLoRAAdapterRequest
) -> ErrorResponse | None:
# Check if both 'lora_name' and 'lora_path' are provided
if not request.lora_name or not request.lora_path:
return create_error_response(
message="Both 'lora_name' and 'lora_path' must be provided.",
err_type="InvalidUserInput",
status_code=HTTPStatus.BAD_REQUEST,
)
# Check if the lora adapter with the given name already exists
if request.lora_name in self.lora_requests:
return create_error_response(
message=f"The lora adapter '{request.lora_name}' has already been "
"loaded.",
err_type="InvalidUserInput",
status_code=HTTPStatus.BAD_REQUEST,
)
return None
async def _check_unload_lora_adapter_request(
self, request: UnloadLoRAAdapterRequest
) -> ErrorResponse | None:
# Check if 'lora_name' is not provided return an error
if not request.lora_name:
return create_error_response(
message="'lora_name' needs to be provided to unload a LoRA adapter.",
err_type="InvalidUserInput",
status_code=HTTPStatus.BAD_REQUEST,
)
# Check if the lora adapter with the given name exists
if request.lora_name not in self.lora_requests:
return create_error_response(
message=f"The lora adapter '{request.lora_name}' cannot be found.",
err_type="NotFoundError",
status_code=HTTPStatus.NOT_FOUND,
)
return None
async def resolve_lora(self, lora_name: str) -> LoRARequest | ErrorResponse:
"""Attempt to resolve a LoRA adapter using available resolvers.
Args:
lora_name: Name/identifier of the LoRA adapter
Returns:
LoRARequest if found and loaded successfully.
ErrorResponse (404) if no resolver finds the adapter.
ErrorResponse (400) if adapter(s) are found but none load.
"""
async with self.lora_resolver_lock[lora_name]:
# First check if this LoRA is already loaded
if lora_name in self.lora_requests:
return self.lora_requests[lora_name]
base_model_name = self.model_config.model
unique_id = self.lora_id_counter.inc(1)
found_adapter = False
# Try to resolve using available resolvers
for resolver in self.lora_resolvers:
lora_request = await resolver.resolve_lora(base_model_name, lora_name)
if lora_request is not None:
found_adapter = True
lora_request.lora_int_id = unique_id
try:
await self.engine_client.add_lora(lora_request)
self.lora_requests[lora_name] = lora_request
logger.info(
"Resolved and loaded LoRA adapter '%s' using %s",
lora_name,
resolver.__class__.__name__,
)
return lora_request
except BaseException as e:
logger.warning(
"Failed to load LoRA '%s' resolved by %s: %s. "
"Trying next resolver.",
lora_name,
resolver.__class__.__name__,
e,
)
continue
if found_adapter:
# An adapter was found, but all attempts to load it failed.
return create_error_response(
message=(
f"LoRA adapter '{lora_name}' was found but could not be loaded."
),
err_type="BadRequestError",
status_code=HTTPStatus.BAD_REQUEST,
)
else:
# No adapter was found
return create_error_response(
message=f"LoRA adapter {lora_name} does not exist",
err_type="NotFoundError",
status_code=HTTPStatus.NOT_FOUND,
)
def create_error_response(
message: str,
err_type: str = "BadRequestError",
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
) -> ErrorResponse:
return ErrorResponse(
error=ErrorInfo(message=message, type=err_type, code=status_code.value)
)

View File

@@ -0,0 +1,346 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import json
import time
from collections.abc import AsyncGenerator, Sequence
from typing import Final, cast
import jinja2
from fastapi import Request
from typing_extensions import assert_never
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (
ErrorResponse,
IOProcessorRequest,
IOProcessorResponse,
PoolingBytesResponse,
PoolingChatRequest,
PoolingCompletionRequest,
PoolingRequest,
PoolingResponse,
PoolingResponseData,
UsageInfo,
)
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.renderer import RenderConfig
from vllm.entrypoints.utils import _validate_truncation_size
from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput
from vllm.tasks import PoolingTask, SupportedTask
from vllm.utils.async_utils import merge_async_iterators
from vllm.utils.serial_utils import (
EmbedDType,
EncodingFormat,
Endianness,
encode_pooling_bytes,
encode_pooling_output,
)
logger = init_logger(__name__)
class OpenAIServingPooling(OpenAIServing):
def __init__(
self,
engine_client: EngineClient,
models: OpenAIServingModels,
*,
supported_tasks: tuple[SupportedTask, ...],
request_logger: RequestLogger | None,
chat_template: str | None,
chat_template_content_format: ChatTemplateContentFormatOption,
trust_request_chat_template: bool = False,
log_error_stack: bool = False,
) -> None:
super().__init__(
engine_client=engine_client,
models=models,
request_logger=request_logger,
log_error_stack=log_error_stack,
)
self.supported_tasks = supported_tasks
self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format
self.trust_request_chat_template = trust_request_chat_template
async def create_pooling(
self,
request: PoolingRequest,
raw_request: Request | None = None,
) -> PoolingResponse | IOProcessorResponse | PoolingBytesResponse | ErrorResponse:
"""
See https://platform.openai.com/docs/api-reference/embeddings/create
for the API specification. This API mimics the OpenAI Embedding API.
"""
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
return error_check_ret
model_name = self.models.model_name()
request_id = f"pool-{self._base_request_id(raw_request)}"
created_time = int(time.time())
is_io_processor_request = isinstance(request, IOProcessorRequest)
try:
lora_request = self._maybe_get_adapters(request)
if self.model_config.skip_tokenizer_init:
tokenizer = None
else:
tokenizer = await self.engine_client.get_tokenizer()
renderer = self._get_renderer(tokenizer)
if getattr(request, "dimensions", None) is not None:
return self.create_error_response(
"dimensions is currently not supported"
)
truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None)
truncate_prompt_tokens = _validate_truncation_size(
self.max_model_len, truncate_prompt_tokens
)
if is_io_processor_request:
if self.io_processor is None:
raise ValueError(
"No IOProcessor plugin installed. Please refer "
"to the documentation and to the "
"'prithvi_geospatial_mae_io_processor' "
"offline inference example for more details."
)
validated_prompt = self.io_processor.parse_request(request)
engine_prompts = await self.io_processor.pre_process_async(
prompt=validated_prompt, request_id=request_id
)
if not isinstance(engine_prompts, Sequence) or isinstance(
engine_prompts, (str, bytes, bytearray)
):
engine_prompts = [engine_prompts]
elif isinstance(request, PoolingChatRequest):
error_check_ret = self._validate_chat_template(
request_chat_template=request.chat_template,
chat_template_kwargs=request.chat_template_kwargs,
trust_request_chat_template=self.trust_request_chat_template,
)
if error_check_ret is not None:
return error_check_ret
(
_,
_,
engine_prompts,
) = await self._preprocess_chat(
request,
tokenizer,
request.messages,
chat_template=request.chat_template or self.chat_template,
chat_template_content_format=self.chat_template_content_format,
# In pooling requests, we are not generating tokens,
# so there is no need to append extra tokens to the input
add_generation_prompt=False,
continue_final_message=False,
add_special_tokens=request.add_special_tokens,
)
elif isinstance(request, PoolingCompletionRequest):
engine_prompts = await renderer.render_prompt(
prompt_or_prompts=request.input,
config=self._build_render_config(request),
)
else:
raise ValueError(f"Unsupported request of type {type(request)}")
except (ValueError, TypeError, jinja2.TemplateError) as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
# Schedule the request and get the result generator.
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
try:
if is_io_processor_request:
assert self.io_processor is not None and isinstance(
request, IOProcessorRequest
)
pooling_params = self.io_processor.validate_or_generate_params()
else:
pooling_params = request.to_pooling_params()
pooling_task: PoolingTask
if request.task is None:
if "token_embed" in self.supported_tasks:
pooling_task = "token_embed"
elif "token_classify" in self.supported_tasks:
pooling_task = "token_classify"
elif "plugin" in self.supported_tasks:
pooling_task = "plugin"
else:
return self.create_error_response(
f"pooling_task must be one of {self.supported_tasks}."
)
else:
pooling_task = request.task
if pooling_task not in self.supported_tasks:
return self.create_error_response(
f"Task {pooling_task} is not supported, it"
f" must be one of {self.supported_tasks}."
)
try:
pooling_params.verify(pooling_task, self.model_config)
except ValueError as e:
return self.create_error_response(str(e))
for i, engine_prompt in enumerate(engine_prompts):
request_id_item = f"{request_id}-{i}"
self._log_inputs(
request_id_item,
engine_prompt,
params=pooling_params,
lora_request=lora_request,
)
trace_headers = (
None
if raw_request is None
else await self._get_trace_headers(raw_request.headers)
)
generator = self.engine_client.encode(
engine_prompt,
pooling_params,
request_id_item,
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
)
generators.append(generator)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
result_generator = merge_async_iterators(*generators)
if is_io_processor_request:
assert self.io_processor is not None
output = await self.io_processor.post_process_async(
model_output=result_generator,
request_id=request_id,
)
return self.io_processor.output_to_response(output)
assert isinstance(request, (PoolingCompletionRequest, PoolingChatRequest))
num_prompts = len(engine_prompts)
# Non-streaming response
final_res_batch: list[PoolingRequestOutput | None]
final_res_batch = [None] * num_prompts
try:
async for i, res in result_generator:
final_res_batch[i] = res
assert all(final_res is not None for final_res in final_res_batch)
final_res_batch_checked = cast(list[PoolingRequestOutput], final_res_batch)
response = self.request_output_to_pooling_response(
final_res_batch_checked,
request_id,
created_time,
model_name,
request.encoding_format,
request.embed_dtype,
request.endianness,
)
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
return response
def request_output_to_pooling_response(
self,
final_res_batch: list[PoolingRequestOutput],
request_id: str,
created_time: int,
model_name: str,
encoding_format: EncodingFormat,
embed_dtype: EmbedDType,
endianness: Endianness,
) -> PoolingResponse | PoolingBytesResponse:
def encode_float_base64():
items: list[PoolingResponseData] = []
num_prompt_tokens = 0
for idx, final_res in enumerate(final_res_batch):
item = PoolingResponseData(
index=idx,
data=encode_pooling_output(
final_res,
encoding_format=encoding_format,
embed_dtype=embed_dtype,
endianness=endianness,
),
)
prompt_token_ids = final_res.prompt_token_ids
items.append(item)
num_prompt_tokens += len(prompt_token_ids)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
total_tokens=num_prompt_tokens,
)
return PoolingResponse(
id=request_id,
created=created_time,
model=model_name,
data=items,
usage=usage,
)
def encode_bytes():
body, items, usage = encode_pooling_bytes(
pooling_outputs=final_res_batch,
embed_dtype=embed_dtype,
endianness=endianness,
)
metadata = {
"id": request_id,
"created": created_time,
"model": model_name,
"data": items,
"usage": usage,
}
return PoolingBytesResponse(
body=body,
metadata=json.dumps(metadata),
)
if encoding_format == "float" or encoding_format == "base64":
return encode_float_base64()
elif encoding_format == "bytes":
return encode_bytes()
else:
assert_never(encoding_format)
def _build_render_config(self, request: PoolingCompletionRequest) -> RenderConfig:
return RenderConfig(
max_length=self.max_model_len,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,503 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import time
from collections.abc import AsyncGenerator, Mapping
from typing import Any
from fastapi import Request
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (
ErrorResponse,
RerankDocument,
RerankRequest,
RerankResponse,
RerankResult,
RerankUsage,
ScoreRequest,
ScoreResponse,
ScoreResponseData,
UsageInfo,
)
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.score_utils import (
ScoreContentPartParam,
ScoreMultiModalParam,
_cosine_similarity,
_validate_score_input_lens,
compress_token_type_ids,
get_score_prompt,
)
from vllm.entrypoints.utils import _validate_truncation_size
from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils.async_utils import make_async, merge_async_iterators
logger = init_logger(__name__)
class ServingScores(OpenAIServing):
def __init__(
self,
engine_client: EngineClient,
models: OpenAIServingModels,
*,
request_logger: RequestLogger | None,
log_error_stack: bool = False,
) -> None:
super().__init__(
engine_client=engine_client,
models=models,
request_logger=request_logger,
log_error_stack=log_error_stack,
)
async def _embedding_score(
self,
tokenizer: AnyTokenizer,
texts_1: list[str],
texts_2: list[str],
request: RerankRequest | ScoreRequest,
request_id: str,
tokenization_kwargs: dict[str, Any] | None = None,
lora_request: LoRARequest | None | None = None,
trace_headers: Mapping[str, str] | None = None,
) -> list[PoolingRequestOutput] | ErrorResponse:
input_texts = texts_1 + texts_2
engine_prompts: list[TokensPrompt] = []
tokenize_async = make_async(
tokenizer.__call__, executor=self._tokenizer_executor
)
tokenization_kwargs = tokenization_kwargs or {}
tokenized_prompts = await asyncio.gather(
*(tokenize_async(t, **tokenization_kwargs) for t in input_texts)
)
for tok_result, input_text in zip(tokenized_prompts, input_texts):
text_token_prompt = self._validate_input(
request, tok_result["input_ids"], input_text
)
engine_prompts.append(
TokensPrompt(prompt_token_ids=text_token_prompt["prompt_token_ids"])
)
# Schedule the request and get the result generator.
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
pooling_params = request.to_pooling_params()
try:
pooling_params.verify("embed", self.model_config)
except ValueError as e:
return self.create_error_response(str(e))
for i, engine_prompt in enumerate(engine_prompts):
request_id_item = f"{request_id}-{i}"
self._log_inputs(
request_id_item,
input_texts[i],
params=pooling_params,
lora_request=lora_request,
)
generators.append(
self.engine_client.encode(
engine_prompt,
pooling_params,
request_id_item,
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
)
)
result_generator = merge_async_iterators(*generators)
# Non-streaming response
final_res_batch: list[PoolingRequestOutput] = []
embeddings: list[PoolingRequestOutput | None] = [None] * len(engine_prompts)
async for i, res in result_generator:
embeddings[i] = res
emb_texts_1: list[PoolingRequestOutput] = []
emb_texts_2: list[PoolingRequestOutput] = []
for i in range(0, len(texts_1)):
assert (emb := embeddings[i]) is not None
emb_texts_1.append(emb)
for i in range(len(texts_1), len(embeddings)):
assert (emb := embeddings[i]) is not None
emb_texts_2.append(emb)
if len(emb_texts_1) == 1:
emb_texts_1 = emb_texts_1 * len(emb_texts_2)
final_res_batch = _cosine_similarity(
tokenizer=tokenizer, embed_1=emb_texts_1, embed_2=emb_texts_2
)
return final_res_batch
def _preprocess_score(
self,
request: RerankRequest | ScoreRequest,
tokenizer: AnyTokenizer,
tokenization_kwargs: dict[str, Any],
data_1: str | ScoreContentPartParam,
data_2: str | ScoreContentPartParam,
) -> tuple[str, TokensPrompt]:
model_config = self.model_config
full_prompt, engine_prompt = get_score_prompt(
model_config=model_config,
data_1=data_1,
data_2=data_2,
tokenizer=tokenizer,
tokenization_kwargs=tokenization_kwargs,
)
self._validate_input(request, engine_prompt["prompt_token_ids"], full_prompt)
if request.mm_processor_kwargs is not None:
engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs
return full_prompt, engine_prompt
async def _cross_encoding_score(
self,
tokenizer: AnyTokenizer,
data_1: list[str] | list[ScoreContentPartParam],
data_2: list[str] | list[ScoreContentPartParam],
request: RerankRequest | ScoreRequest,
request_id: str,
tokenization_kwargs: dict[str, Any] | None = None,
lora_request: LoRARequest | None | None = None,
trace_headers: Mapping[str, str] | None = None,
) -> list[PoolingRequestOutput] | ErrorResponse:
request_prompts: list[str] = []
engine_prompts: list[TokensPrompt] = []
if len(data_1) == 1:
data_1 = data_1 * len(data_2)
if isinstance(tokenizer, MistralTokenizer):
raise ValueError("MistralTokenizer not supported for cross-encoding")
tokenization_kwargs = tokenization_kwargs or {}
input_pairs = [(t1, t2) for t1, t2 in zip(data_1, data_2)]
preprocess_async = make_async(
self._preprocess_score, executor=self._tokenizer_executor
)
preprocessed_prompts = await asyncio.gather(
*(
preprocess_async(
request=request,
tokenizer=tokenizer,
tokenization_kwargs=tokenization_kwargs,
data_1=t1,
data_2=t2,
)
for t1, t2 in input_pairs
)
)
for full_prompt, engine_prompt in preprocessed_prompts:
request_prompts.append(full_prompt)
engine_prompts.append(engine_prompt)
# Schedule the request and get the result generator.
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
default_pooling_params = request.to_pooling_params()
try:
default_pooling_params.verify("score", self.model_config)
except ValueError as e:
return self.create_error_response(str(e))
for i, engine_prompt in enumerate(engine_prompts):
request_id_item = f"{request_id}-{i}"
self._log_inputs(
request_id_item,
request_prompts[i],
params=default_pooling_params,
lora_request=lora_request,
)
if token_type_ids := engine_prompt.pop("token_type_ids", None):
pooling_params = default_pooling_params.clone()
compressed = compress_token_type_ids(token_type_ids)
pooling_params.extra_kwargs = {"compressed_token_type_ids": compressed}
else:
pooling_params = default_pooling_params
generator = self.engine_client.encode(
engine_prompt,
pooling_params,
request_id_item,
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
)
generators.append(generator)
result_generator = merge_async_iterators(*generators)
# Non-streaming response
final_res_batch: list[PoolingRequestOutput | None] = [None] * len(
engine_prompts
)
async for i, res in result_generator:
final_res_batch[i] = res
return [out for out in final_res_batch if out is not None]
async def _run_scoring(
self,
data_1: list[str] | str | ScoreMultiModalParam,
data_2: list[str] | str | ScoreMultiModalParam,
request: ScoreRequest | RerankRequest,
request_id: str,
raw_request: Request | None = None,
) -> list[PoolingRequestOutput] | ErrorResponse:
lora_request = self._maybe_get_adapters(request)
tokenizer = await self.engine_client.get_tokenizer()
truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None)
tokenization_kwargs: dict[str, Any] = {}
_validate_truncation_size(
self.max_model_len, truncate_prompt_tokens, tokenization_kwargs
)
trace_headers = (
None
if raw_request is None
else await self._get_trace_headers(raw_request.headers)
)
if not self.model_config.is_multimodal_model and (
isinstance(data_1, dict) or isinstance(data_2, dict)
):
raise ValueError(
f"MultiModalParam is not supported for {self.model_config.architecture}" # noqa: E501
)
if isinstance(data_1, str):
data_1 = [data_1]
elif isinstance(data_1, dict):
data_1 = data_1.get("content") # type: ignore[assignment]
if isinstance(data_2, str):
data_2 = [data_2]
elif isinstance(data_2, dict):
data_2 = data_2.get("content") # type: ignore[assignment]
_validate_score_input_lens(data_1, data_2) # type: ignore[arg-type]
if self.model_config.is_cross_encoder:
return await self._cross_encoding_score(
tokenizer=tokenizer,
data_1=data_1, # type: ignore[arg-type]
data_2=data_2, # type: ignore[arg-type]
request=request,
request_id=request_id,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
trace_headers=trace_headers,
)
else:
return await self._embedding_score(
tokenizer=tokenizer,
texts_1=data_1, # type: ignore[arg-type]
texts_2=data_2, # type: ignore[arg-type]
request=request,
request_id=request_id,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
trace_headers=trace_headers,
)
async def create_score(
self,
request: ScoreRequest,
raw_request: Request | None = None,
) -> ScoreResponse | ErrorResponse:
"""
Score API similar to Sentence Transformers cross encoder
See https://sbert.net/docs/package_reference/cross_encoder
"""
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
return error_check_ret
request_id = f"score-{self._base_request_id(raw_request)}"
created_time = int(time.time())
try:
final_res_batch = await self._run_scoring(
request.text_1,
request.text_2,
request,
request_id,
raw_request,
)
if isinstance(final_res_batch, ErrorResponse):
return final_res_batch
return self.request_output_to_score_response(
final_res_batch,
request_id,
created_time,
self.models.model_name(),
)
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
async def do_rerank(
self, request: RerankRequest, raw_request: Request | None = None
) -> RerankResponse | ErrorResponse:
"""
Rerank API based on JinaAI's rerank API; implements the same
API interface. Designed for compatibility with off-the-shelf
tooling, since this is a common standard for reranking APIs
See example client implementations at
https://github.com/infiniflow/ragflow/blob/main/rag/llm/rerank_model.py
numerous clients use this standard.
"""
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
return error_check_ret
request_id = f"rerank-{self._base_request_id(raw_request)}"
documents = request.documents
top_n = (
request.top_n
if request.top_n > 0
else (
len(documents)
if isinstance(documents, list)
else len(documents["content"])
)
)
try:
final_res_batch = await self._run_scoring(
request.query,
documents,
request,
request_id,
raw_request,
)
if isinstance(final_res_batch, ErrorResponse):
return final_res_batch
return self.request_output_to_rerank_response(
final_res_batch,
request_id,
self.models.model_name(),
documents,
top_n,
)
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
def request_output_to_score_response(
self,
final_res_batch: list[PoolingRequestOutput],
request_id: str,
created_time: int,
model_name: str,
) -> ScoreResponse:
items: list[ScoreResponseData] = []
num_prompt_tokens = 0
for idx, final_res in enumerate(final_res_batch):
classify_res = ScoringRequestOutput.from_base(final_res)
item = ScoreResponseData(
index=idx,
score=classify_res.outputs.score,
)
prompt_token_ids = final_res.prompt_token_ids
items.append(item)
num_prompt_tokens += len(prompt_token_ids)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
total_tokens=num_prompt_tokens,
)
return ScoreResponse(
id=request_id,
created=created_time,
model=model_name,
data=items,
usage=usage,
)
def request_output_to_rerank_response(
self,
final_res_batch: list[PoolingRequestOutput],
request_id: str,
model_name: str,
documents: list[str] | ScoreMultiModalParam,
top_n: int,
) -> RerankResponse:
"""
Convert the output of do_rank to a RerankResponse
"""
results: list[RerankResult] = []
num_prompt_tokens = 0
for idx, final_res in enumerate(final_res_batch):
classify_res = ScoringRequestOutput.from_base(final_res)
result = RerankResult(
index=idx,
document=RerankDocument(text=documents[idx])
if isinstance(documents, list)
else RerankDocument(multi_modal=documents["content"][idx]),
relevance_score=classify_res.outputs.score,
)
results.append(result)
prompt_token_ids = final_res.prompt_token_ids
num_prompt_tokens += len(prompt_token_ids)
# sort by relevance, then return the top n if set
results.sort(key=lambda x: x.relevance_score, reverse=True)
if top_n < len(documents):
results = results[:top_n]
return RerankResponse(
id=request_id,
model=model_name,
results=results,
usage=RerankUsage(total_tokens=num_prompt_tokens),
)

View File

@@ -0,0 +1,203 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Any, Final
import jinja2
from fastapi import Request
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (
DetokenizeRequest,
DetokenizeResponse,
ErrorResponse,
TokenizeChatRequest,
TokenizeRequest,
TokenizeResponse,
TokenizerInfoResponse,
)
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.renderer import RenderConfig
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
logger = init_logger(__name__)
class OpenAIServingTokenization(OpenAIServing):
def __init__(
self,
engine_client: EngineClient,
models: OpenAIServingModels,
*,
request_logger: RequestLogger | None,
chat_template: str | None,
chat_template_content_format: ChatTemplateContentFormatOption,
trust_request_chat_template: bool = False,
log_error_stack: bool = False,
) -> None:
super().__init__(
engine_client=engine_client,
models=models,
request_logger=request_logger,
log_error_stack=log_error_stack,
)
self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format
self.trust_request_chat_template = trust_request_chat_template
async def create_tokenize(
self,
request: TokenizeRequest,
raw_request: Request,
) -> TokenizeResponse | ErrorResponse:
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
return error_check_ret
request_id = f"tokn-{self._base_request_id(raw_request)}"
try:
lora_request = self._maybe_get_adapters(request)
tokenizer = await self.engine_client.get_tokenizer()
renderer = self._get_renderer(tokenizer)
if isinstance(request, TokenizeChatRequest):
tool_dicts = (
None
if request.tools is None
else [tool.model_dump() for tool in request.tools]
)
error_check_ret = self._validate_chat_template(
request_chat_template=request.chat_template,
chat_template_kwargs=request.chat_template_kwargs,
trust_request_chat_template=self.trust_request_chat_template,
)
if error_check_ret is not None:
return error_check_ret
(
_,
_,
engine_prompts,
) = await self._preprocess_chat(
request,
tokenizer,
request.messages,
tool_dicts=tool_dicts,
chat_template=request.chat_template or self.chat_template,
chat_template_content_format=self.chat_template_content_format,
add_generation_prompt=request.add_generation_prompt,
continue_final_message=request.continue_final_message,
chat_template_kwargs=request.chat_template_kwargs,
add_special_tokens=request.add_special_tokens,
)
else:
engine_prompts = await renderer.render_prompt(
prompt_or_prompts=request.prompt,
config=self._build_render_config(request),
)
except (ValueError, TypeError, jinja2.TemplateError) as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(f"{e} {e.__cause__}")
input_ids: list[int] = []
for engine_prompt in engine_prompts:
self._log_inputs(
request_id, engine_prompt, params=None, lora_request=lora_request
)
if isinstance(engine_prompt, dict) and "prompt_token_ids" in engine_prompt:
input_ids.extend(engine_prompt["prompt_token_ids"])
token_strs = None
if request.return_token_strs:
token_strs = tokenizer.convert_ids_to_tokens(input_ids)
return TokenizeResponse(
tokens=input_ids,
token_strs=token_strs,
count=len(input_ids),
max_model_len=self.max_model_len,
)
async def create_detokenize(
self,
request: DetokenizeRequest,
raw_request: Request,
) -> DetokenizeResponse | ErrorResponse:
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
return error_check_ret
request_id = f"tokn-{self._base_request_id(raw_request)}"
lora_request = self._maybe_get_adapters(request)
tokenizer = await self.engine_client.get_tokenizer()
self._log_inputs(
request_id, request.tokens, params=None, lora_request=lora_request
)
prompt_input = await self._tokenize_prompt_input_async(
request,
tokenizer,
request.tokens,
)
input_text = prompt_input["prompt"]
return DetokenizeResponse(prompt=input_text)
async def get_tokenizer_info(
self,
) -> TokenizerInfoResponse | ErrorResponse:
"""Get comprehensive tokenizer information."""
try:
tokenizer = await self.engine_client.get_tokenizer()
info = TokenizerInfo(tokenizer, self.chat_template).to_dict()
return TokenizerInfoResponse(**info)
except Exception as e:
return self.create_error_response(f"Failed to get tokenizer info: {str(e)}")
def _build_render_config(self, request: TokenizeRequest) -> RenderConfig:
return RenderConfig(add_special_tokens=request.add_special_tokens)
@dataclass
class TokenizerInfo:
tokenizer: AnyTokenizer
chat_template: str | None
def to_dict(self) -> dict[str, Any]:
"""Return the tokenizer configuration."""
return self._get_tokenizer_config()
def _get_tokenizer_config(self) -> dict[str, Any]:
"""Get tokenizer configuration directly from the tokenizer object."""
config = dict(getattr(self.tokenizer, "init_kwargs", None) or {})
# Remove file path fields
config.pop("vocab_file", None)
config.pop("merges_file", None)
config = self._make_json_serializable(config)
config["tokenizer_class"] = type(self.tokenizer).__name__
if self.chat_template:
config["chat_template"] = self.chat_template
return config
def _make_json_serializable(self, obj):
"""Convert any non-JSON-serializable objects to serializable format."""
if hasattr(obj, "content"):
return obj.content
elif isinstance(obj, dict):
return {k: self._make_json_serializable(v) for k, v in obj.items()}
elif isinstance(obj, list):
return [self._make_json_serializable(item) for item in obj]
else:
return obj

View File

@@ -0,0 +1,269 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import time
from collections.abc import AsyncGenerator
from collections.abc import Sequence as GenericSequence
from fastapi import Request
# yapf: disable
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (
ChatCompletionLogProb,
ChatCompletionLogProbs,
ChatCompletionLogProbsContent,
ErrorResponse,
GenerateRequest,
GenerateResponse,
GenerateResponseChoice,
PromptTokenUsageInfo,
RequestResponseMetadata,
UsageInfo,
)
from vllm.entrypoints.openai.serving_engine import OpenAIServing, clamp_prompt_logprobs
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
from vllm.logger import init_logger
from vllm.logprobs import Logprob
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.utils.collection_utils import as_list
logger = init_logger(__name__)
class ServingTokens(OpenAIServing):
"""Provides Tokens IN <> Tokens OUT functionality to vLLM API."""
def __init__(
self,
engine_client: EngineClient,
models: OpenAIServingModels,
*,
request_logger: RequestLogger | None,
force_no_detokenize: bool = False,
return_tokens_as_token_ids: bool = False,
log_error_stack: bool = False,
enable_prompt_tokens_details: bool = False,
enable_log_outputs: bool = False,
):
super().__init__(engine_client=engine_client,
models=models,
request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids,
log_error_stack=log_error_stack)
self.enable_prompt_tokens_details = enable_prompt_tokens_details
self.enable_log_outputs = enable_log_outputs
self.force_no_detokenize = force_no_detokenize
if force_no_detokenize:
logger.info("Tokens-only mode is enabled, skipping detokenization "
"step for incoming requests.")
async def serve_tokens(
self,
request: GenerateRequest,
raw_request: Request | None = None
) -> GenerateResponse | ErrorResponse:
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
logger.error("Error with model %s", error_check_ret)
return error_check_ret
# If the engine is dead, raise the engine's DEAD_ERROR.
# This is required for the streaming case, where we return a
# success status before we actually start generating text :).
if self.engine_client.errored:
raise self.engine_client.dead_error
lora_request = None
lora_request = self._maybe_get_adapters(request,
supports_default_mm_loras=True)
model_name = self.models.model_name(lora_request)
request_id = "generate-tokens-" \
f"{self._base_request_id(raw_request, request.request_id)}"
request_metadata = RequestResponseMetadata(request_id=request_id)
if raw_request:
raw_request.state.request_metadata = request_metadata
# TODO(NickLucche): Change to EngineCoreRequest once Renderer work is
# completed
engine_prompt = EngineTokensPrompt(prompt_token_ids=request.token_ids)
if request.features is not None:
engine_prompt["multi_modal_data"] = None
if hasattr(request, "cache_salt") and request.cache_salt is not None:
engine_prompt["cache_salt"] = request.cache_salt
# Schedule the request and get the result generator.
result_generator: AsyncGenerator[RequestOutput, None] | None = None
try:
sampling_params = request.sampling_params
if self.force_no_detokenize:
sampling_params.detokenize = False
self._log_inputs(request_id,
request.token_ids,
params=sampling_params,
lora_request=lora_request)
trace_headers = (None if raw_request is None else await
self._get_trace_headers(raw_request.headers))
result_generator = self.engine_client.generate(
engine_prompt,
sampling_params,
request_id,
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
)
except ValueError as e:
return self.create_error_response(str(e))
# TODO(NickLucche): Implement streaming response
try:
assert result_generator is not None
return await self.serve_tokens_full_generator(
request, result_generator, request_id, model_name,
request_metadata)
except ValueError as e:
return self.create_error_response(str(e))
async def serve_tokens_full_generator(
self,
request: GenerateRequest,
result_generator: AsyncGenerator[RequestOutput, None],
request_id: str,
model_name: str,
request_metadata: RequestResponseMetadata,
) -> ErrorResponse | GenerateResponse:
created_time = int(time.time())
final_res: RequestOutput | None = None
sampling_params: SamplingParams = request.sampling_params
try:
async for res in result_generator:
final_res = res
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
except ValueError as e:
return self.create_error_response(str(e))
assert final_res is not None
choices: list[GenerateResponseChoice] = []
num_generated_tokens = 0
for output in final_res.outputs:
token_ids = output.token_ids
out_logprobs = output.logprobs
# This is top_logprobs in completions API
if sampling_params.logprobs:
assert out_logprobs is not None, "Did not output logprobs"
logprobs = self._create_tokens_logprobs(
token_ids=token_ids,
top_logprobs=out_logprobs,
num_output_top_logprobs=sampling_params.logprobs,
)
else:
logprobs = None
choice_data = GenerateResponseChoice(
index=output.index,
logprobs=logprobs,
finish_reason=output.finish_reason
if output.finish_reason else "stop",
token_ids=as_list(output.token_ids))
choices.append(choice_data)
num_generated_tokens += len(output.token_ids)
assert final_res.prompt_token_ids is not None
num_prompt_tokens = len(final_res.prompt_token_ids)
if final_res.encoder_prompt_token_ids is not None:
num_prompt_tokens += len(final_res.encoder_prompt_token_ids)
usage = UsageInfo(prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens,
total_tokens=num_prompt_tokens +
num_generated_tokens)
if self.enable_prompt_tokens_details and final_res.num_cached_tokens:
# This info is not available at the /coordinator level
usage.prompt_tokens_details = PromptTokenUsageInfo(
cached_tokens=final_res.num_cached_tokens)
request_metadata.final_usage_info = usage
response = GenerateResponse(
id=request_id,
created=created_time,
model=model_name,
choices=choices,
usage=usage,
prompt_logprobs=clamp_prompt_logprobs(final_res.prompt_logprobs),
kv_transfer_params=final_res.kv_transfer_params,
)
# Log complete response if output logging is enabled
if self.enable_log_outputs and self.request_logger:
for choice in choices:
# Get the corresponding output token IDs
output_token_ids = None
if choice.index < len(final_res.outputs):
output_token_ids = final_res.outputs[
choice.index].token_ids
if output_token_ids:
# Log token_ids only.
self.request_logger.log_outputs(
request_id=request_id,
outputs="",
output_token_ids=output_token_ids,
finish_reason=choice.finish_reason,
is_streaming=False,
delta=False,
)
return response
def _create_tokens_logprobs(
self,
token_ids: GenericSequence[int],
top_logprobs: GenericSequence[dict[int, Logprob] | None],
num_output_top_logprobs: int | None = None,
) -> ChatCompletionLogProbs:
"""Create OpenAI-style logprobs."""
logprobs_content: list[ChatCompletionLogProbsContent] = []
for i, token_id in enumerate(token_ids):
token = f"token_id:{token_id}"
step_top_logprobs = top_logprobs[i]
if step_top_logprobs is None or step_top_logprobs.get(
token_id) is None:
logprobs_content.append(
ChatCompletionLogProbsContent(token=token, ))
else:
step_token = step_top_logprobs[token_id]
logprobs_content.append(
ChatCompletionLogProbsContent(
token=token,
logprob=max(step_token.logprob, -9999.0),
top_logprobs=[
ChatCompletionLogProb(
token=token,
logprob=max(p[1].logprob, -9999.0),
) for i, p in enumerate(step_top_logprobs.items())
if num_output_top_logprobs
and i < num_output_top_logprobs
]))
return ChatCompletionLogProbs(content=logprobs_content)

View File

@@ -0,0 +1,148 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import AsyncGenerator
from fastapi import Request
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (
ErrorResponse,
RequestResponseMetadata,
TranscriptionRequest,
TranscriptionResponse,
TranscriptionResponseStreamChoice,
TranscriptionStreamResponse,
TranslationRequest,
TranslationResponse,
TranslationResponseStreamChoice,
TranslationStreamResponse,
)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.openai.speech_to_text import OpenAISpeechToText
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
logger = init_logger(__name__)
class OpenAIServingTranscription(OpenAISpeechToText):
"""Handles transcription requests."""
def __init__(
self,
engine_client: EngineClient,
models: OpenAIServingModels,
*,
request_logger: RequestLogger | None,
return_tokens_as_token_ids: bool = False,
log_error_stack: bool = False,
enable_force_include_usage: bool = False,
):
super().__init__(
engine_client=engine_client,
models=models,
request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids,
task_type="transcribe",
log_error_stack=log_error_stack,
enable_force_include_usage=enable_force_include_usage,
)
async def create_transcription(
self, audio_data: bytes, request: TranscriptionRequest, raw_request: Request
) -> TranscriptionResponse | AsyncGenerator[str, None] | ErrorResponse:
"""Transcription API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/audio/createTranscription
for the API specification. This API mimics the OpenAI transcription API.
"""
return await self._create_speech_to_text(
audio_data=audio_data,
request=request,
raw_request=raw_request,
response_class=TranscriptionResponse,
stream_generator_method=self.transcription_stream_generator,
)
async def transcription_stream_generator(
self,
request: TranscriptionRequest,
result_generator: list[AsyncGenerator[RequestOutput, None]],
request_id: str,
request_metadata: RequestResponseMetadata,
audio_duration_s: float,
) -> AsyncGenerator[str, None]:
generator = self._speech_to_text_stream_generator(
request=request,
list_result_generator=result_generator,
request_id=request_id,
request_metadata=request_metadata,
audio_duration_s=audio_duration_s,
chunk_object_type="transcription.chunk",
response_stream_choice_class=TranscriptionResponseStreamChoice,
stream_response_class=TranscriptionStreamResponse,
)
async for chunk in generator:
yield chunk
class OpenAIServingTranslation(OpenAISpeechToText):
"""Handles translation requests."""
def __init__(
self,
engine_client: EngineClient,
models: OpenAIServingModels,
*,
request_logger: RequestLogger | None,
return_tokens_as_token_ids: bool = False,
log_error_stack: bool = False,
enable_force_include_usage: bool = False,
):
super().__init__(
engine_client=engine_client,
models=models,
request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids,
task_type="translate",
log_error_stack=log_error_stack,
enable_force_include_usage=enable_force_include_usage,
)
async def create_translation(
self, audio_data: bytes, request: TranslationRequest, raw_request: Request
) -> TranslationResponse | AsyncGenerator[str, None] | ErrorResponse:
"""Translation API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/audio/createTranslation
for the API specification. This API mimics the OpenAI translation API.
"""
return await self._create_speech_to_text(
audio_data=audio_data,
request=request,
raw_request=raw_request,
response_class=TranslationResponse,
stream_generator_method=self.translation_stream_generator,
)
async def translation_stream_generator(
self,
request: TranslationRequest,
result_generator: list[AsyncGenerator[RequestOutput, None]],
request_id: str,
request_metadata: RequestResponseMetadata,
audio_duration_s: float,
) -> AsyncGenerator[str, None]:
generator = self._speech_to_text_stream_generator(
request=request,
list_result_generator=result_generator,
request_id=request_id,
request_metadata=request_metadata,
audio_duration_s=audio_duration_s,
chunk_object_type="translation.chunk",
response_stream_choice_class=TranslationResponseStreamChoice,
stream_response_class=TranslationStreamResponse,
)
async for chunk in generator:
yield chunk

View File

@@ -0,0 +1,405 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import io
import math
import time
from collections.abc import AsyncGenerator, Callable
from functools import cached_property
from typing import Literal, TypeAlias, TypeVar, cast
import numpy as np
from fastapi import Request
import vllm.envs as envs
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (
DeltaMessage,
ErrorResponse,
RequestResponseMetadata,
TranscriptionResponse,
TranscriptionResponseStreamChoice,
TranscriptionStreamResponse,
TranslationResponse,
TranslationResponseStreamChoice,
TranslationStreamResponse,
UsageInfo,
)
from vllm.entrypoints.openai.serving_engine import OpenAIServing, SpeechToTextRequest
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.inputs.data import PromptType
from vllm.logger import init_logger
from vllm.model_executor.models import SupportsTranscription
from vllm.outputs import RequestOutput
from vllm.utils.import_utils import PlaceholderModule
try:
import librosa
except ImportError:
librosa = PlaceholderModule("librosa") # type: ignore[assignment]
SpeechToTextResponse: TypeAlias = TranscriptionResponse | TranslationResponse
T = TypeVar("T", bound=SpeechToTextResponse)
logger = init_logger(__name__)
class OpenAISpeechToText(OpenAIServing):
"""Base class for speech-to-text operations like transcription and
translation."""
def __init__(
self,
engine_client: EngineClient,
models: OpenAIServingModels,
*,
request_logger: RequestLogger | None,
return_tokens_as_token_ids: bool = False,
task_type: Literal["transcribe", "translate"] = "transcribe",
log_error_stack: bool = False,
enable_force_include_usage: bool = False,
):
super().__init__(
engine_client=engine_client,
models=models,
request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids,
log_error_stack=log_error_stack,
)
self.default_sampling_params = self.model_config.get_diff_sampling_param()
self.task_type = task_type
self.asr_config = self.model_cls.get_speech_to_text_config(
self.model_config, task_type
)
self.enable_force_include_usage = enable_force_include_usage
self.max_audio_filesize_mb = envs.VLLM_MAX_AUDIO_CLIP_FILESIZE_MB
if self.default_sampling_params:
logger.info(
"Overwriting default completion sampling param with: %s",
self.default_sampling_params,
)
@cached_property
def model_cls(self) -> type[SupportsTranscription]:
from vllm.model_executor.model_loader import get_model_cls
model_cls = get_model_cls(self.model_config)
return cast(type[SupportsTranscription], model_cls)
async def _preprocess_speech_to_text(
self,
request: SpeechToTextRequest,
audio_data: bytes,
) -> tuple[list[PromptType], float]:
# Validate request
language = self.model_cls.validate_language(request.language)
# Skip to_language validation to avoid extra logging for Whisper.
to_language = (
self.model_cls.validate_language(request.to_language)
if request.to_language
else None
)
if len(audio_data) / 1024**2 > self.max_audio_filesize_mb:
raise ValueError("Maximum file size exceeded.")
with io.BytesIO(audio_data) as bytes_:
# NOTE resample to model SR here for efficiency. This is also a
# pre-requisite for chunking, as it assumes Whisper SR.
y, sr = librosa.load(bytes_, sr=self.asr_config.sample_rate)
duration = librosa.get_duration(y=y, sr=sr)
do_split_audio = (
self.asr_config.allow_audio_chunking
and duration > self.asr_config.max_audio_clip_s
)
chunks = [y] if not do_split_audio else self._split_audio(y, int(sr))
prompts = []
for chunk in chunks:
# The model has control over the construction, as long as it
# returns a valid PromptType.
prompt = self.model_cls.get_generation_prompt(
audio=chunk,
stt_config=self.asr_config,
model_config=self.model_config,
language=language,
task_type=self.task_type,
request_prompt=request.prompt,
to_language=to_language,
)
prompts.append(prompt)
return prompts, duration
async def _create_speech_to_text(
self,
audio_data: bytes,
request: SpeechToTextRequest,
raw_request: Request,
response_class: type[T],
stream_generator_method: Callable[..., AsyncGenerator[str, None]],
) -> T | AsyncGenerator[str, None] | ErrorResponse:
"""Base method for speech-to-text operations like transcription and
translation."""
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
return error_check_ret
# If the engine is dead, raise the engine's DEAD_ERROR.
# This is required for the streaming case, where we return a
# success status before we actually start generating text :).
if self.engine_client.errored:
raise self.engine_client.dead_error
if request.response_format not in ["text", "json"]:
return self.create_error_response(
"Currently only support response_format `text` or `json`"
)
request_id = f"{self.task_type}-{self._base_request_id(raw_request)}"
request_metadata = RequestResponseMetadata(request_id=request_id)
if raw_request:
raw_request.state.request_metadata = request_metadata
try:
lora_request = self._maybe_get_adapters(request)
prompts, duration_s = await self._preprocess_speech_to_text(
request=request,
audio_data=audio_data,
)
except ValueError as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
list_result_generator: list[AsyncGenerator[RequestOutput, None]] | None = None
try:
# Unlike most decoder-only models, whisper generation length is not
# constrained by the size of the input audio, which is mapped to a
# fixed-size log-mel-spectogram.
default_max_tokens = self.model_config.max_model_len
sampling_params = request.to_sampling_params(
default_max_tokens, self.default_sampling_params
)
self._log_inputs(
request_id,
# It will not display special tokens like <|startoftranscript|>
request.prompt,
params=sampling_params,
lora_request=lora_request,
)
list_result_generator = [
self.engine_client.generate(
prompt,
sampling_params,
request_id,
lora_request=lora_request,
)
for prompt in prompts
]
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
if request.stream:
return stream_generator_method(
request, list_result_generator, request_id, request_metadata, duration_s
)
# Non-streaming response.
try:
assert list_result_generator is not None
text = ""
for result_generator in list_result_generator:
async for op in result_generator:
text += op.outputs[0].text
if self.task_type == "transcribe":
# add usage in TranscriptionResponse.
usage = {
"type": "duration",
# rounded up as per openAI specs
"seconds": int(math.ceil(duration_s)),
}
final_response = cast(T, response_class(text=text, usage=usage))
else:
# no usage in response for translation task
final_response = cast(T, response_class(text=text)) # type: ignore[call-arg]
return final_response
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
async def _speech_to_text_stream_generator(
self,
request: SpeechToTextRequest,
list_result_generator: list[AsyncGenerator[RequestOutput, None]],
request_id: str,
request_metadata: RequestResponseMetadata,
audio_duration_s: float,
chunk_object_type: Literal["translation.chunk", "transcription.chunk"],
response_stream_choice_class: type[TranscriptionResponseStreamChoice]
| type[TranslationResponseStreamChoice],
stream_response_class: type[TranscriptionStreamResponse]
| type[TranslationStreamResponse],
) -> AsyncGenerator[str, None]:
created_time = int(time.time())
model_name = request.model
completion_tokens = 0
num_prompt_tokens = 0
include_usage = self.enable_force_include_usage or request.stream_include_usage
include_continuous_usage = (
request.stream_continuous_usage_stats
if include_usage and request.stream_continuous_usage_stats
else False
)
try:
for result_generator in list_result_generator:
async for res in result_generator:
# On first result.
if res.prompt_token_ids is not None:
num_prompt_tokens = len(res.prompt_token_ids)
if audio_tokens := self.model_cls.get_num_audio_tokens(
audio_duration_s, self.asr_config, self.model_config
):
num_prompt_tokens += audio_tokens
# We need to do it here, because if there are exceptions in
# the result_generator, it needs to be sent as the FIRST
# response (by the try...catch).
# Just one output (n=1) supported.
assert len(res.outputs) == 1
output = res.outputs[0]
delta_message = DeltaMessage(content=output.text)
completion_tokens += len(output.token_ids)
if output.finish_reason is None:
# Still generating, send delta update.
choice_data = response_stream_choice_class(delta=delta_message)
else:
# Model is finished generating.
choice_data = response_stream_choice_class(
delta=delta_message,
finish_reason=output.finish_reason,
stop_reason=output.stop_reason,
)
chunk = stream_response_class(
id=request_id,
object=chunk_object_type,
created=created_time,
choices=[choice_data],
model=model_name,
)
# handle usage stats if requested & if continuous
if include_continuous_usage:
chunk.usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=num_prompt_tokens + completion_tokens,
)
data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n"
# Once the final token is handled, if stream_options.include_usage
# is sent, send the usage.
if include_usage:
final_usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=num_prompt_tokens + completion_tokens,
)
final_usage_chunk = stream_response_class(
id=request_id,
object=chunk_object_type,
created=created_time,
choices=[],
model=model_name,
usage=final_usage,
)
final_usage_data = final_usage_chunk.model_dump_json(
exclude_unset=True, exclude_none=True
)
yield f"data: {final_usage_data}\n\n"
# report to FastAPI middleware aggregate usage across all choices
request_metadata.final_usage_info = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=num_prompt_tokens + completion_tokens,
)
except Exception as e:
# TODO: Use a vllm-specific Validation Error
logger.exception("Error in %s stream generator.", self.task_type)
data = self.create_streaming_error_response(str(e))
yield f"data: {data}\n\n"
# Send the final done message after all response.n are finished
yield "data: [DONE]\n\n"
def _split_audio(
self, audio_data: np.ndarray, sample_rate: int
) -> list[np.ndarray]:
chunk_size = sample_rate * self.asr_config.max_audio_clip_s
overlap_size = sample_rate * self.asr_config.overlap_chunk_second
chunks = []
i = 0
while i < audio_data.shape[-1]:
if i + chunk_size >= audio_data.shape[-1]:
# handle last chunk
chunks.append(audio_data[..., i:])
break
# Find the best split point in the overlap region
search_start = i + chunk_size - overlap_size
search_end = min(i + chunk_size, audio_data.shape[-1])
split_point = self._find_split_point(audio_data, search_start, search_end)
# Extract chunk up to the split point
chunks.append(audio_data[..., i:split_point])
i = split_point
return chunks
def _find_split_point(self, wav: np.ndarray, start_idx: int, end_idx: int) -> int:
"""Find the best point to split audio by
looking for silence or low amplitude.
Args:
wav: Audio tensor [1, T]
start_idx: Start index of search region
end_idx: End index of search region
Returns:
Index of best splitting point
"""
segment = wav[start_idx:end_idx]
# Calculate RMS energy in small windows
min_energy = math.inf
quietest_idx = 0
min_energy_window = self.asr_config.min_energy_split_window_size
assert min_energy_window is not None
for i in range(0, len(segment) - min_energy_window, min_energy_window):
window = segment[i : i + min_energy_window]
energy = (window**2).mean() ** 0.5
if energy < min_energy:
quietest_idx = i + start_idx
min_energy = energy
return quietest_idx

View File

@@ -0,0 +1,142 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser,
ToolParserManager,
)
__all__ = ["ToolParser", "ToolParserManager"]
"""
Register a lazy module mapping.
Example:
ToolParserManager.register_lazy_module(
name="kimi_k2",
module_path="vllm.entrypoints.openai.tool_parsers.kimi_k2_parser",
class_name="KimiK2ToolParser",
)
"""
_TOOL_PARSERS_TO_REGISTER = {
"deepseek_v3": ( # name
"deepseekv3_tool_parser", # filename
"DeepSeekV3ToolParser", # class_name
),
"deepseek_v31": (
"deepseekv31_tool_parser",
"DeepSeekV31ToolParser",
),
"ernie45": (
"ernie45_tool_parser",
"Ernie45ToolParser",
),
"glm45": (
"glm4_moe_tool_parser",
"Glm4MoeModelToolParser",
),
"granite-20b-fc": (
"granite_20b_fc_tool_parser",
"Granite20bFCToolParser",
),
"granite": (
"granite_tool_parser",
"GraniteToolParser",
),
"hermes": (
"hermes_tool_parser",
"Hermes2ProToolParser",
),
"hunyuan_a13b": (
"hunyuan_a13b_tool_parser",
"HunyuanA13BToolParser",
),
"internlm": (
"internlm2_tool_parser",
"Internlm2ToolParser",
),
"jamba": (
"jamba_tool_parser",
"JambaToolParser",
),
"kimi_k2": (
"kimi_k2_tool_parser",
"KimiK2ToolParser",
),
"llama3_json": (
"llama_tool_parser",
"Llama3JsonToolParser",
),
"llama4_json": (
"llama_tool_parser",
"Llama3JsonToolParser",
),
"llama4_pythonic": (
"llama4_pythonic_tool_parser",
"Llama4PythonicToolParser",
),
"longcat": (
"longcat_tool_parser",
"LongcatFlashToolParser",
),
"minimax_m2": (
"minimax_m2_tool_parser",
"MinimaxM2ToolParser",
),
"minimax": (
"minimax_tool_parser",
"MinimaxToolParser",
),
"mistral": (
"mistral_tool_parser",
"MistralToolParser",
),
"olmo3": (
"olmo3_tool_parser",
"Olmo3PythonicToolParser",
),
"openai": (
"openai_tool_parser",
"OpenAIToolParser",
),
"phi4_mini_json": (
"phi4mini_tool_parser",
"Phi4MiniJsonToolParser",
),
"pythonic": (
"pythonic_tool_parser",
"PythonicToolParser",
),
"qwen3_coder": (
"qwen3coder_tool_parser",
"Qwen3CoderToolParser",
),
"qwen3_xml": (
"qwen3xml_tool_parser",
"Qwen3XMLToolParser",
),
"seed_oss": (
"seed_oss_tool_parser",
"SeedOssToolParser",
),
"step3": (
"step3_tool_parser",
"Step3ToolParser",
),
"xlam": (
"xlam_tool_parser",
"xLAMToolParser",
),
}
def register_lazy_tool_parsers():
for name, (file_name, class_name) in _TOOL_PARSERS_TO_REGISTER.items():
module_path = f"vllm.entrypoints.openai.tool_parsers.{file_name}"
ToolParserManager.register_lazy_module(name, module_path, class_name)
register_lazy_tool_parsers()

Some files were not shown because too many files have changed in this diff Show More