v1.0
This commit is contained in:
319
entrypoints/utils.py
Normal file
319
entrypoints/utils.py
Normal file
@@ -0,0 +1,319 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import functools
|
||||
import os
|
||||
from argparse import Namespace
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from fastapi import Request
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from starlette.background import BackgroundTask, BackgroundTasks
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.chat_utils import (
|
||||
load_chat_template,
|
||||
resolve_hf_chat_template,
|
||||
resolve_mistral_chat_template,
|
||||
)
|
||||
from vllm.entrypoints.openai.cli_args import make_arg_parser
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
CompletionRequest,
|
||||
StreamOptions,
|
||||
)
|
||||
from vllm.entrypoints.openai.serving_models import LoRAModulePath
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.transformers_utils.tokenizers import MistralTokenizer
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
|
||||
logger = init_logger(__name__)


# Epilog text appended to `vllm {subcmd}` argument parsers.  It points users
# at the filtered --help modes (full list, per-section, per-flag) and at the
# online documentation.  `{subcmd}` is substituted by the caller.
VLLM_SUBCMD_PARSER_EPILOG = (
    "For full list: vllm {subcmd} --help=all\n"
    "For a section: vllm {subcmd} --help=ModelConfig (case-insensitive)\n"  # noqa: E501
    "For a flag: vllm {subcmd} --help=max-model-len (_ or - accepted)\n"  # noqa: E501
    "Documentation: https://docs.vllm.ai\n"
)
|
||||
|
||||
|
||||
async def listen_for_disconnect(request: Request) -> None:
    """Block until the client disconnects, then return.

    Drains incoming ASGI messages until an ``http.disconnect`` arrives.
    If server-load tracking is enabled and the counter has been created,
    the disconnect also decrements the in-flight load counter, since the
    cancelled handler never gets the chance to do so itself.
    """
    while True:
        message = await request.receive()
        if message["type"] != "http.disconnect":
            continue
        state = request.app.state
        # Single combined check: tracking must be switched on *and* the
        # counter must already exist before we decrement it.
        if getattr(state, "enable_server_load_tracking", False) and hasattr(
            state, "server_load_metrics"
        ):
            state.server_load_metrics -= 1
        return
|
||||
|
||||
|
||||
def with_cancellation(handler_func):
    """Decorator making a route handler cancellable on client disconnect.

    ``request.is_disconnected`` does not play well with middleware, so this
    follows the starlette.StreamingResponse pattern instead: race two tasks,
    one doing the real work and one waiting for an ``http.disconnect``
    message, and cancel whichever loses.

    Assumes the request body has already been consumed (true for fastapi
    handlers whose body was parsed into a pydantic model) — the disconnect
    listener throws away every incoming message while it waits, so this
    decorator is unsafe anywhere that assumption does not hold.

    When the handler returns a `StreamingResponse`, the response object takes
    over disconnect handling and this wrapper stops listening.
    """

    # functools.wraps keeps the original signature visible to fastapi so the
    # route's request type hinting still works.
    @functools.wraps(handler_func)
    async def wrapper(*args, **kwargs):
        # Request arrives either as the second positional arg or as the
        # `raw_request` keyword.
        raw_request = args[1] if len(args) > 1 else kwargs["raw_request"]

        work = asyncio.create_task(handler_func(*args, **kwargs))
        watchdog = asyncio.create_task(listen_for_disconnect(raw_request))

        finished, unfinished = await asyncio.wait(
            [work, watchdog], return_when=asyncio.FIRST_COMPLETED
        )
        for task in unfinished:
            task.cancel()

        # The handler won the race: propagate its result (or exception).
        # Otherwise the client went away and there is nothing to return.
        return work.result() if work in finished else None

    return wrapper
|
||||
|
||||
|
||||
def decrement_server_load(request: Request):
    """Decrement the app-wide in-flight load counter.

    Used as a response background task by `load_aware_call` so the counter
    drops only after the response has finished sending.
    """
    request.app.state.server_load_metrics -= 1
|
||||
|
||||
|
||||
def load_aware_call(func):
    """Decorator tracking in-flight request count in app.state.

    Increments ``server_load_metrics`` before awaiting the handler and
    guarantees exactly one matching decrement: immediately on exception,
    via a background task after a JSON/streaming response finishes sending,
    or immediately for any other response type.  No-ops entirely when
    ``enable_server_load_tracking`` is off.
    """

    @functools.wraps(func)
    async def wrapper(*args, **kwargs):
        # Request arrives either as `raw_request` kwarg or the second
        # positional argument (matching fastapi handler conventions).
        raw_request = kwargs.get("raw_request", args[1] if len(args) > 1 else None)

        if raw_request is None:
            raise ValueError(
                "raw_request required when server load tracking is enabled"
            )

        # Tracking disabled: pass straight through with no bookkeeping.
        if not getattr(raw_request.app.state, "enable_server_load_tracking", False):
            return await func(*args, **kwargs)

        # ensure the counter exists
        if not hasattr(raw_request.app.state, "server_load_metrics"):
            raw_request.app.state.server_load_metrics = 0

        raw_request.app.state.server_load_metrics += 1
        try:
            response = await func(*args, **kwargs)
        except Exception:
            # Handler failed: no response will be sent, undo the increment.
            raw_request.app.state.server_load_metrics -= 1
            raise

        if isinstance(response, (JSONResponse, StreamingResponse)):
            # Defer the decrement until the response body has actually been
            # sent, by attaching it as a background task.
            if response.background is None:
                response.background = BackgroundTask(decrement_server_load, raw_request)
            # NOTE: the BackgroundTasks branch must be checked before the
            # BackgroundTask branch — BackgroundTasks is the more specific
            # (container) type and would otherwise be mishandled.
            elif isinstance(response.background, BackgroundTasks):
                response.background.add_task(decrement_server_load, raw_request)
            elif isinstance(response.background, BackgroundTask):
                # Convert the single BackgroundTask to BackgroundTasks
                # and chain the decrement_server_load task to it
                tasks = BackgroundTasks()
                tasks.add_task(
                    response.background.func,
                    *response.background.args,
                    **response.background.kwargs,
                )
                tasks.add_task(decrement_server_load, raw_request)
                response.background = tasks
        else:
            # Non-streaming/non-JSON responses finish immediately; decrement
            # now rather than via a background task.
            raw_request.app.state.server_load_metrics -= 1

        return response

    return wrapper
|
||||
|
||||
|
||||
def cli_env_setup():
    """Prepare the process environment for a vLLM CLI invocation.

    Selects the ``spawn`` multiprocessing start method for workers (unless
    the user set VLLM_WORKER_MULTIPROC_METHOD themselves), because the
    default ``fork`` method is incompatible with some accelerators and the
    Python default is slated to change in future versions anyway.

    This is done only in the CLI entrypoint: switching library users to
    ``spawn`` could break existing code, since ``spawn`` misbehaves when the
    caller's code is not guarded by ``if __name__ == "__main__":``.

    References:
    - https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods
    - https://pytorch.org/docs/stable/notes/multiprocessing.html#cuda-in-multiprocessing
    - https://pytorch.org/docs/stable/multiprocessing.html#sharing-cuda-tensors
    - https://docs.habana.ai/en/latest/PyTorch/Getting_Started_with_PyTorch_and_Gaudi/Getting_Started_with_PyTorch.html?highlight=multiprocessing#torch-multiprocessing-for-dataloaders
    """
    if "VLLM_WORKER_MULTIPROC_METHOD" not in os.environ:
        logger.debug("Setting VLLM_WORKER_MULTIPROC_METHOD to 'spawn'")
        os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
|
||||
|
||||
|
||||
def _validate_truncation_size(
|
||||
max_model_len: int,
|
||||
truncate_prompt_tokens: int | None,
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
) -> int | None:
|
||||
if truncate_prompt_tokens is not None:
|
||||
if truncate_prompt_tokens <= -1:
|
||||
truncate_prompt_tokens = max_model_len
|
||||
|
||||
if truncate_prompt_tokens > max_model_len:
|
||||
raise ValueError(
|
||||
f"truncate_prompt_tokens value ({truncate_prompt_tokens}) "
|
||||
f"is greater than max_model_len ({max_model_len})."
|
||||
f" Please, select a smaller truncation size."
|
||||
)
|
||||
|
||||
if tokenization_kwargs is not None:
|
||||
tokenization_kwargs["truncation"] = True
|
||||
tokenization_kwargs["max_length"] = truncate_prompt_tokens
|
||||
|
||||
else:
|
||||
if tokenization_kwargs is not None:
|
||||
tokenization_kwargs["truncation"] = False
|
||||
|
||||
return truncate_prompt_tokens
|
||||
|
||||
|
||||
def get_max_tokens(
    max_model_len: int,
    request: ChatCompletionRequest | CompletionRequest,
    input_length: int,
    default_sampling_params: dict,
) -> int:
    """Compute the effective output-token budget for a request.

    Takes the minimum of every applicable cap that is actually set: the
    remaining context window, the request's own limit
    (`max_completion_tokens` falling back to `max_tokens`), the platform's
    output-token ceiling, and any server-side default.
    """
    requested = getattr(request, "max_completion_tokens", None) or request.max_tokens
    candidates = (
        max_model_len - input_length,
        requested,
        current_platform.get_max_output_tokens(input_length),
        default_sampling_params.get("max_tokens"),
    )
    # Unset caps are None and simply drop out of the minimum.
    return min(limit for limit in candidates if limit is not None)
|
||||
|
||||
|
||||
def log_non_default_args(args: Namespace | EngineArgs):
    """Log every argument whose value differs from its default.

    Accepts either a parsed CLI `Namespace` (compared against a parser run
    with no arguments) or an `EngineArgs` dataclass (compared field-by-field
    against a freshly constructed default instance).

    Raises:
        TypeError: for any other argument type.
    """
    overrides = {}

    if isinstance(args, Namespace):
        # Baseline = what the CLI parser produces with zero arguments.
        baseline = vars(make_arg_parser(FlexibleArgumentParser()).parse_args([]))
        for name, default in baseline.items():
            value = getattr(args, name)
            if value != default:
                overrides[name] = value
    elif isinstance(args, EngineArgs):
        baseline_args = EngineArgs(model=args.model)  # Create default instance
        for field in dataclasses.fields(args):
            value = getattr(args, field.name)
            if value != getattr(baseline_args, field.name):
                overrides[field.name] = value
        # `model` was copied into the baseline above, so the loop can never
        # flag it; report it separately when it differs from the class default.
        if baseline_args.model != EngineArgs.model:
            overrides["model"] = baseline_args.model
    else:
        raise TypeError(
            "Unsupported argument type. Must be Namespace or EngineArgs instance."
        )

    logger.info("non-default args: %s", overrides)
|
||||
|
||||
|
||||
def should_include_usage(
    stream_options: StreamOptions | None, enable_force_include_usage: bool
) -> tuple[bool, bool]:
    """Resolve the usage-reporting flags for a streaming response.

    Returns `(include_usage, include_continuous_usage)`.  The force flag
    can switch usage reporting on even when the client did not ask for it;
    continuous per-chunk usage additionally requires the client to have
    requested `continuous_usage_stats`.
    """
    if not stream_options:
        # No stream options supplied: usage only if the server forces it,
        # and never continuously.
        return enable_force_include_usage, False

    usage = stream_options.include_usage or enable_force_include_usage
    continuous = usage and bool(stream_options.continuous_usage_stats)
    return usage, continuous
|
||||
|
||||
|
||||
def process_lora_modules(
    args_lora_modules: list[LoRAModulePath], default_mm_loras: dict[str, str] | None
) -> list[LoRAModulePath]:
    """Merge CLI-provided LoRA modules with default multimodal LoRAs.

    Args:
        args_lora_modules: LoRA modules from the CLI (may be None).
        default_mm_loras: mapping of modality name -> LoRA path to append.

    Returns:
        The combined list of LoRA module paths.
    """
    if not default_mm_loras:
        return args_lora_modules

    default_mm_lora_paths = [
        LoRAModulePath(
            name=modality,
            path=lora_path,
        )
        for modality, lora_path in default_mm_loras.items()
    ]

    if args_lora_modules is None:
        return default_mm_lora_paths
    # Build a new list rather than extending in place: the original `+=`
    # mutated the caller's args_lora_modules list through aliasing.
    return args_lora_modules + default_mm_lora_paths
|
||||
|
||||
|
||||
async def process_chat_template(
    args_chat_template: Path | str | None,
    engine_client: EngineClient,
    model_config: ModelConfig,
) -> str | None:
    """Load and resolve the chat template to use for this server.

    Loads the user-supplied template (if any) and, when one is present,
    resolves it against the tokenizer: Mistral tokenizers go through
    `resolve_mistral_chat_template`, while HF tokenizers are compared
    against their official template and a warning is logged on mismatch.

    Returns the resolved template string, or None when no template was
    supplied.
    """
    resolved_chat_template = load_chat_template(args_chat_template)
    if resolved_chat_template is None:
        return None

    # Fetch the tokenizer so we can compare against its official template.
    tokenizer = await engine_client.get_tokenizer()

    if isinstance(tokenizer, MistralTokenizer):
        # The warning is logged in resolve_mistral_chat_template.
        return resolve_mistral_chat_template(chat_template=resolved_chat_template)

    hf_chat_template = resolve_hf_chat_template(
        tokenizer=tokenizer,
        chat_template=None,
        tools=None,
        model_config=model_config,
    )
    if hf_chat_template != resolved_chat_template:
        logger.warning(
            "Using supplied chat template: %s\n"
            "It is different from official chat template '%s'. "
            "This discrepancy may lead to performance degradation.",
            resolved_chat_template,
            model_config.model,
        )
    return resolved_chat_template
|
||||
Reference in New Issue
Block a user