Sync from v0.13

This commit is contained in:
2026-01-19 10:38:50 +08:00
parent b2ef04d792
commit 5aef6c175a
3714 changed files with 854317 additions and 89342 deletions

View File

@@ -0,0 +1,111 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Ensure we perform lazy loading in vllm/__init__.py.
i.e: appears only within the `if typing.TYPE_CHECKING:` guard,
**except** for a short whitelist.
"""
import ast
import sys
from collections.abc import Iterable
from pathlib import Path
from typing import Final
INIT_PATH: Final = Path("vllm/__init__.py")
# If you need to add items to whitelist, do it here.
ALLOWED_IMPORTS: Final[frozenset[str]] = frozenset(
{
"vllm.env_override",
}
)
ALLOWED_FROM_MODULES: Final[frozenset[str]] = frozenset(
{
".version",
}
)
def _is_internal(name: str | None, *, level: int = 0) -> bool:
if level > 0:
return True
if name is None:
return False
return name.startswith("vllm.") or name == "vllm"
def _fail(violations: Iterable[tuple[int, str]]) -> None:
print("ERROR: Disallowed eager imports in vllm/__init__.py:\n", file=sys.stderr)
for lineno, msg in violations:
print(f" Line {lineno}: {msg}", file=sys.stderr)
sys.exit(1)
def main() -> None:
source = INIT_PATH.read_text(encoding="utf-8")
tree = ast.parse(source, filename=str(INIT_PATH))
violations: list[tuple[int, str]] = []
class Visitor(ast.NodeVisitor):
def __init__(self) -> None:
super().__init__()
self._in_type_checking = False
def visit_If(self, node: ast.If) -> None:
guard_is_type_checking = False
test = node.test
if isinstance(test, ast.Attribute) and isinstance(test.value, ast.Name):
guard_is_type_checking = (
test.value.id == "typing" and test.attr == "TYPE_CHECKING"
)
elif isinstance(test, ast.Name):
guard_is_type_checking = test.id == "TYPE_CHECKING"
if guard_is_type_checking:
prev = self._in_type_checking
self._in_type_checking = True
for child in node.body:
self.visit(child)
self._in_type_checking = prev
for child in node.orelse:
self.visit(child)
else:
self.generic_visit(node)
def visit_Import(self, node: ast.Import) -> None:
if self._in_type_checking:
return
for alias in node.names:
module_name = alias.name
if _is_internal(module_name) and module_name not in ALLOWED_IMPORTS:
violations.append(
(
node.lineno,
f"import '{module_name}' must be inside typing.TYPE_CHECKING", # noqa: E501
)
)
def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
if self._in_type_checking:
return
module_as_written = ("." * node.level) + (node.module or "")
if (
_is_internal(node.module, level=node.level)
and module_as_written not in ALLOWED_FROM_MODULES
):
violations.append(
(
node.lineno,
f"from '{module_as_written}' import ... must be inside typing.TYPE_CHECKING", # noqa: E501
)
)
Visitor().visit(tree)
if violations:
_fail(violations)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,108 @@
#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import sys
import regex as re
# List of files (relative to repo root) that are allowed to import pickle or
# cloudpickle
#
# STOP AND READ BEFORE YOU ADD ANYTHING ELSE TO THIS LIST:
# The pickle and cloudpickle modules are known to be unsafe when deserializing
# data from potentially untrusted parties. They have resulted in multiple CVEs
# for vLLM and numerous vulnerabilities in the Python ecosystem more broadly.
# Before adding new uses of pickle/cloudpickle, please consider safer
# alternatives like msgpack or pydantic that are already in use in vLLM. Only
# add to this list if absolutely necessary and after careful security review.
ALLOWED_FILES = {
# pickle
"vllm/multimodal/hasher.py",
"vllm/transformers_utils/config.py",
"vllm/model_executor/models/registry.py",
"vllm/compilation/caching.py",
"vllm/distributed/utils.py",
"vllm/distributed/parallel_state.py",
"vllm/distributed/device_communicators/all_reduce_utils.py",
"vllm/distributed/device_communicators/shm_broadcast.py",
"vllm/distributed/device_communicators/shm_object_storage.py",
"vllm/utils/hashing.py",
"tests/tokenizers_/test_hf.py",
"tests/utils_/test_hashing.py",
"benchmarks/kernels/graph_machete_bench.py",
"benchmarks/kernels/benchmark_lora.py",
"benchmarks/kernels/benchmark_machete.py",
"benchmarks/fused_kernels/layernorm_rms_benchmarks.py",
"benchmarks/cutlass_benchmarks/w8a8_benchmarks.py",
"benchmarks/cutlass_benchmarks/sparse_benchmarks.py",
# cloudpickle
"vllm/v1/executor/multiproc_executor.py",
"vllm/v1/executor/ray_executor.py",
"vllm/entrypoints/llm.py",
"tests/utils.py",
# pickle and cloudpickle
"vllm/v1/serial_utils.py",
}
PICKLE_RE = re.compile(
r"^\s*(import\s+(pickle|cloudpickle)(\s|$|\sas)"
r"|from\s+(pickle|cloudpickle)\s+import\b)"
)
def scan_file(path: str) -> int:
with open(path, encoding="utf-8") as f:
for i, line in enumerate(f, 1):
if PICKLE_RE.match(line):
print(
f"{path}:{i}: "
"\033[91merror:\033[0m " # red color
"Found pickle/cloudpickle import"
)
return 1
return 0
def main():
returncode = 0
for filename in sys.argv[1:]:
if filename in ALLOWED_FILES:
continue
returncode |= scan_file(filename)
return returncode
def test_regex():
test_cases = [
# Should match
("import pickle", True),
("import cloudpickle", True),
("import pickle as pkl", True),
("import cloudpickle as cpkl", True),
("from pickle import *", True),
("from cloudpickle import dumps", True),
("from pickle import dumps, loads", True),
("from cloudpickle import (dumps, loads)", True),
(" import pickle", True),
("\timport cloudpickle", True),
("from pickle import loads", True),
# Should not match
("import somethingelse", False),
("from somethingelse import pickle", False),
("# import pickle", False),
("print('import pickle')", False),
("import pickleas as asdf", False),
]
for i, (line, should_match) in enumerate(test_cases):
result = bool(PICKLE_RE.match(line))
assert result == should_match, (
f"Test case {i} failed: '{line}' (expected {should_match}, got {result})"
)
print("All regex tests passed.")
if __name__ == "__main__":
if "--test-regex" in sys.argv:
test_regex()
else:
sys.exit(main())

View File

@@ -0,0 +1,151 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import sys
from enum import Enum
class SPDXStatus(Enum):
"""SPDX header status enumeration"""
EMPTY = "empty" # empty __init__.py
COMPLETE = "complete"
MISSING_LICENSE = "missing_license" # Only has copyright line
MISSING_COPYRIGHT = "missing_copyright" # Only has license line
MISSING_BOTH = "missing_both" # Completely missing
FULL_SPDX_HEADER = (
"# SPDX-License-Identifier: Apache-2.0\n"
"# SPDX-FileCopyrightText: Copyright contributors to the vLLM project"
)
LICENSE_LINE = "# SPDX-License-Identifier: Apache-2.0"
COPYRIGHT_LINE = "# SPDX-FileCopyrightText: Copyright contributors to the vLLM project" # noqa: E501
def check_spdx_header_status(file_path):
"""Check SPDX header status of the file"""
with open(file_path, encoding="UTF-8") as file:
lines = file.readlines()
if not lines:
# Empty file
return SPDXStatus.EMPTY
# Skip shebang line
start_idx = 0
if lines and lines[0].startswith("#!"):
start_idx = 1
has_license = False
has_copyright = False
# Check all lines for SPDX headers (not just the first two)
for i in range(start_idx, len(lines)):
line = lines[i].strip()
if line == LICENSE_LINE:
has_license = True
elif line == COPYRIGHT_LINE:
has_copyright = True
# Determine status based on what we found
if has_license and has_copyright:
return SPDXStatus.COMPLETE
elif has_license and not has_copyright:
# Only has license line
return SPDXStatus.MISSING_COPYRIGHT
# Only has copyright line
elif not has_license and has_copyright:
return SPDXStatus.MISSING_LICENSE
else:
# Completely missing both lines
return SPDXStatus.MISSING_BOTH
def add_header(file_path, status):
"""Add or supplement SPDX header based on status"""
with open(file_path, "r+", encoding="UTF-8") as file:
lines = file.readlines()
file.seek(0, 0)
file.truncate()
if status == SPDXStatus.MISSING_BOTH:
# Completely missing, add complete header
if lines and lines[0].startswith("#!"):
# Preserve shebang line
file.write(lines[0])
file.write(FULL_SPDX_HEADER + "\n")
file.writelines(lines[1:])
else:
# Add header directly
file.write(FULL_SPDX_HEADER + "\n")
file.writelines(lines)
elif status == SPDXStatus.MISSING_COPYRIGHT:
# Only has license line, need to add copyright line
# Find the license line and add copyright line after it
for i, line in enumerate(lines):
if line.strip() == LICENSE_LINE:
# Insert copyright line after license line
lines.insert(
i + 1,
f"{COPYRIGHT_LINE}\n",
)
break
file.writelines(lines)
elif status == SPDXStatus.MISSING_LICENSE:
# Only has copyright line, need to add license line
# Find the copyright line and add license line before it
for i, line in enumerate(lines):
if line.strip() == COPYRIGHT_LINE:
# Insert license line before copyright line
lines.insert(i, f"{LICENSE_LINE}\n")
break
file.writelines(lines)
def main():
"""Main function"""
files_missing_both = []
files_missing_copyright = []
files_missing_license = []
for file_path in sys.argv[1:]:
status = check_spdx_header_status(file_path)
if status == SPDXStatus.MISSING_BOTH:
files_missing_both.append(file_path)
elif status == SPDXStatus.MISSING_COPYRIGHT:
files_missing_copyright.append(file_path)
elif status == SPDXStatus.MISSING_LICENSE:
files_missing_license.append(file_path)
else:
continue
# Collect all files that need fixing
all_files_to_fix = (
files_missing_both + files_missing_copyright + files_missing_license
)
if all_files_to_fix:
print("The following files are missing the SPDX header:")
if files_missing_both:
for file_path in files_missing_both:
print(f" {file_path}")
add_header(file_path, SPDXStatus.MISSING_BOTH)
if files_missing_copyright:
for file_path in files_missing_copyright:
print(f" {file_path}")
add_header(file_path, SPDXStatus.MISSING_COPYRIGHT)
if files_missing_license:
for file_path in files_missing_license:
print(f" {file_path}")
add_header(file_path, SPDXStatus.MISSING_LICENSE)
sys.exit(1 if all_files_to_fix else 0)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,88 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import subprocess
import sys
import regex as re
FORBIDDEN_IMPORT_RE = re.compile(r"^(from|import)\s+triton(\s|\.|$)")
# the way allowed to import triton
ALLOWED_LINES = {
"from vllm.triton_utils import triton",
"from vllm.triton_utils import tl",
"from vllm.triton_utils import tl, triton",
}
ALLOWED_FILES = {"vllm/triton_utils/importing.py"}
def is_allowed_file(current_file: str) -> bool:
return current_file in ALLOWED_FILES
def is_forbidden_import(line: str) -> bool:
stripped = line.strip()
return bool(FORBIDDEN_IMPORT_RE.match(stripped)) and stripped not in ALLOWED_LINES
def parse_diff(diff: str) -> list[str]:
violations = []
current_file = None
current_lineno = None
skip_allowed_file = False
for line in diff.splitlines():
if line.startswith("+++ b/"):
current_file = line[6:]
skip_allowed_file = is_allowed_file(current_file)
elif skip_allowed_file:
continue
elif line.startswith("@@"):
match = re.search(r"\+(\d+)", line)
if match:
current_lineno = int(match.group(1)) - 1 # next "+ line" is here
elif line.startswith("+") and not line.startswith("++"):
current_lineno += 1
code_line = line[1:]
if is_forbidden_import(code_line):
violations.append(
f"{current_file}:{current_lineno}: {code_line.strip()}"
)
return violations
def get_diff(diff_type: str) -> str:
if diff_type == "staged":
return subprocess.check_output(
["git", "diff", "--cached", "--unified=0"], text=True
)
elif diff_type == "unstaged":
return subprocess.check_output(["git", "diff", "--unified=0"], text=True)
else:
raise ValueError(f"Unknown diff_type: {diff_type}")
def main():
all_violations = []
for diff_type in ["staged", "unstaged"]:
try:
diff_output = get_diff(diff_type)
violations = parse_diff(diff_output)
all_violations.extend(violations)
except subprocess.CalledProcessError as e:
print(f"[{diff_type}] Git diff failed: {e}", file=sys.stderr)
if all_violations:
print(
"❌ Forbidden direct `import triton` detected."
" ➤ Use `from vllm.triton_utils import triton` instead.\n"
)
for v in all_violations:
print(f"{v}")
return 1
return 0
if __name__ == "__main__":
sys.exit(main())

View File

@@ -0,0 +1,83 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import subprocess
from pathlib import Path
import regex as re
FORBIDDEN_PATTERNS = re.compile(r"^\s*(?:import\s+re(?:$|\s|,)|from\s+re\s+import)")
ALLOWED_PATTERNS = [
re.compile(r"^\s*import\s+regex\s+as\s+re\s*$"),
re.compile(r"^\s*import\s+regex\s*$"),
]
def get_staged_python_files() -> list[str]:
try:
result = subprocess.run(
["git", "diff", "--cached", "--name-only", "--diff-filter=AM"],
capture_output=True,
text=True,
check=True,
)
files = result.stdout.strip().split("\n") if result.stdout.strip() else []
return [f for f in files if f.endswith(".py")]
except subprocess.CalledProcessError:
return []
def is_forbidden_import(line: str) -> bool:
line = line.strip()
return bool(
FORBIDDEN_PATTERNS.match(line)
and not any(pattern.match(line) for pattern in ALLOWED_PATTERNS)
)
def check_file(filepath: str) -> list[tuple[int, str]]:
violations = []
try:
with open(filepath, encoding="utf-8") as f:
for line_num, line in enumerate(f, 1):
if is_forbidden_import(line):
violations.append((line_num, line.strip()))
except (OSError, UnicodeDecodeError):
pass
return violations
def main() -> int:
files = get_staged_python_files()
if not files:
return 0
total_violations = 0
for filepath in files:
if not Path(filepath).exists():
continue
if filepath == "setup.py":
continue
violations = check_file(filepath)
if violations:
print(f"\n{filepath}:")
for line_num, line in violations:
print(f" Line {line_num}: {line}")
total_violations += 1
if total_violations > 0:
print(f"\n💡 Found {total_violations} violation(s).")
print("❌ Please replace 'import re' with 'import regex as re'")
print(" Also replace 'from re import ...' with 'from regex import ...'") # noqa: E501
print("✅ Allowed imports:")
print(" - import regex as re")
print(" - import regex") # noqa: E501
return 1
return 0
if __name__ == "__main__":
raise SystemExit(main())

View File

@@ -0,0 +1,34 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Generates specialized requirements files for nightly PyTorch testing.
This script reads the main test requirements input file (`requirements/test.in`)
and splits its content into two files:
1. `requirements/nightly_torch_test.txt`: Contains dependencies
except PyTorch-related.
2. `torch_nightly_test.txt`: Contains only PyTorch-related packages.
"""
input_file = "requirements/test.in"
output_file = "requirements/nightly_torch_test.txt"
# white list of packages that are not compatible with PyTorch nightly directly
# with pip install. Please add your package to this list if it is not compatible
# or make the dependency test fails.
white_list = ["torch", "torchaudio", "torchvision", "mamba_ssm"]
with open(input_file) as f:
lines = f.readlines()
skip_next = False
for line in lines:
if skip_next:
if line.startswith((" ", "\t")) or line.strip() == "":
continue
skip_next = False
if any(k in line.lower() for k in white_list):
skip_next = True
continue

158
tools/pre_commit/mypy.py Executable file
View File

@@ -0,0 +1,158 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Run mypy on changed files.
This script is designed to be used as a pre-commit hook. It runs mypy
on files that have been changed. It groups files into different mypy calls
based on their directory to avoid import following issues.
Usage:
python tools/pre_commit/mypy.py <ci> <python_version> <changed_files...>
Args:
ci: "1" if running in CI, "0" otherwise. In CI, follow_imports is set to
"silent" for the main group of files.
python_version: Python version to use (e.g., "3.10") or "local" to use
the local Python version.
changed_files: List of changed files to check.
"""
import subprocess
import sys
import regex as re
FILES = [
"vllm/*.py",
"vllm/assets",
"vllm/distributed",
"vllm/engine",
"vllm/entrypoints",
"vllm/executor",
"vllm/inputs",
"vllm/logging_utils",
"vllm/multimodal",
"vllm/platforms",
"vllm/plugins",
"vllm/tokenizers",
"vllm/transformers_utils",
"vllm/triton_utils",
"vllm/usage",
"vllm/utils",
"vllm/worker",
"vllm/v1/core",
"vllm/v1/engine",
"vllm/v1/executor",
"vllm/v1/metrics",
"vllm/v1/pool",
"vllm/v1/sample",
"vllm/v1/worker",
]
# After fixing errors resulting from changing follow_imports
# from "skip" to "silent", move the following directories to FILES
SEPARATE_GROUPS = [
"tests",
# v0 related
"vllm/attention",
"vllm/compilation",
"vllm/lora",
"vllm/model_executor",
# v1 related
"vllm/v1/attention",
"vllm/v1/kv_offload",
"vllm/v1/spec_decode",
"vllm/v1/structured_output",
]
# TODO(woosuk): Include the code from Megatron and HuggingFace.
EXCLUDE = [
"vllm/engine/arg_utils.py",
"vllm/model_executor/parallel_utils",
"vllm/model_executor/models",
"vllm/model_executor/layers/fla/ops",
# Ignore triton kernels in ops.
"vllm/attention/ops",
]
def group_files(changed_files: list[str]) -> dict[str, list[str]]:
"""
Group changed files into different mypy calls.
Args:
changed_files: List of changed files.
Returns:
A dictionary mapping file group names to lists of changed files.
"""
exclude_pattern = re.compile(f"^{'|'.join(EXCLUDE)}.*")
files_pattern = re.compile(f"^({'|'.join(FILES)}).*")
file_groups = {"": []}
file_groups.update({k: [] for k in SEPARATE_GROUPS})
for changed_file in changed_files:
# Skip files which should be ignored completely
if exclude_pattern.match(changed_file):
continue
# Group files by mypy call
if files_pattern.match(changed_file):
file_groups[""].append(changed_file)
continue
else:
for directory in SEPARATE_GROUPS:
if re.match(f"^{directory}.*", changed_file):
file_groups[directory].append(changed_file)
break
return file_groups
def mypy(
targets: list[str],
python_version: str | None,
follow_imports: str | None,
file_group: str,
) -> int:
"""
Run mypy on the given targets.
Args:
targets: List of files or directories to check.
python_version: Python version to use (e.g., "3.10") or None to use
the default mypy version.
follow_imports: Value for the --follow-imports option or None to use
the default mypy behavior.
file_group: The file group name for logging purposes.
Returns:
The return code from mypy.
"""
args = ["mypy"]
if python_version is not None:
args += ["--python-version", python_version]
if follow_imports is not None:
args += ["--follow-imports", follow_imports]
print(f"$ {' '.join(args)} {file_group}")
return subprocess.run(args + targets, check=False).returncode
def main():
ci = sys.argv[1] == "1"
python_version = sys.argv[2]
file_groups = group_files(sys.argv[3:])
if python_version == "local":
python_version = f"{sys.version_info.major}.{sys.version_info.minor}"
returncode = 0
for file_group, changed_files in file_groups.items():
follow_imports = None if ci and file_group == "" else "skip"
if changed_files:
returncode |= mypy(
changed_files, python_version, follow_imports, file_group
)
return returncode
if __name__ == "__main__":
sys.exit(main())

15
tools/pre_commit/png-lint.sh Executable file
View File

@@ -0,0 +1,15 @@
#!/bin/bash
# Ensure that *.excalidraw.png files have the excalidraw metadata
# embedded in them. This ensures they can be loaded back into
# the tool and edited in the future.
find . -iname '*.excalidraw.png' | while read -r file; do
if git check-ignore -q "$file"; then
continue
fi
if ! grep -q "excalidraw+json" "$file"; then
echo "$file was not exported from excalidraw with 'Embed Scene' enabled."
exit 1
fi
done

22
tools/pre_commit/shellcheck.sh Executable file
View File

@@ -0,0 +1,22 @@
#!/bin/bash
set -e
scversion="stable"
if [ -d "shellcheck-${scversion}" ]; then
export PATH="$PATH:$(pwd)/shellcheck-${scversion}"
fi
if ! [ -x "$(command -v shellcheck)" ]; then
if [ "$(uname -s)" != "Linux" ] || [ "$(uname -m)" != "x86_64" ]; then
echo "Please install shellcheck: https://github.com/koalaman/shellcheck?tab=readme-ov-file#installing"
exit 1
fi
# automatic local install if linux x86_64
wget -qO- "https://github.com/koalaman/shellcheck/releases/download/${scversion?}/shellcheck-${scversion?}.linux.x86_64.tar.xz" | tar -xJv
export PATH="$PATH:$(pwd)/shellcheck-${scversion}"
fi
# TODO - fix warnings in .buildkite/scripts/hardware_ci/run-amd-test.sh
find . -name "*.sh" ".git" -prune -not -path "./.buildkite/scripts/hardware_ci/run-amd-test.sh" -print0 | xargs -0 -I {} sh -c 'git check-ignore -q "{}" || shellcheck -s bash "{}"'

View File

@@ -0,0 +1,81 @@
#!/bin/bash
# Update Dockerfile dependency graph when docker/Dockerfile changes.
# This script is designed to be used as a pre-commit hook.
set -euo pipefail
# Accept file paths as arguments
FILES=("$@")
# Check if docker/Dockerfile is among the provided files
if printf '%s\n' "${FILES[@]}" | grep -q "^docker/Dockerfile$"; then
echo "docker/Dockerfile has changed, attempting to update dependency graph..."
# Check if Docker is installed and running
if ! command -v docker &> /dev/null; then
echo "Warning: Docker command not found. Skipping Dockerfile graph update."
echo "Please install Docker to automatically update the graph: https://docs.docker.com/get-docker/"
exit 0
fi
if ! docker info &> /dev/null; then
echo "Warning: Docker daemon is not running. Skipping Dockerfile graph update."
echo "Please start Docker to automatically update the graph."
exit 0
fi
# Define the target file path
TARGET_GRAPH_FILE="docs/assets/contributing/dockerfile-stages-dependency.png"
# Ensure target directory exists
mkdir -p "$(dirname "$TARGET_GRAPH_FILE")"
# Store old image hash in a variable if the file exists
OLD_HASH=""
if [ -f "$TARGET_GRAPH_FILE" ]; then
OLD_HASH=$(sha256sum "$TARGET_GRAPH_FILE")
fi
# Generate Dockerfile graph
echo "Running dockerfilegraph tool..."
docker run \
--rm \
--user "$(id -u):$(id -g)" \
--workdir /workspace \
--volume "$(pwd)":/workspace \
ghcr.io/patrickhoefler/dockerfilegraph:alpine \
--output png \
--dpi 200 \
--max-label-length 50 \
--filename docker/Dockerfile \
--legend
echo "Finding generated PNG file..."
# Check for Dockerfile.png in the root directory (most likely location)
if [ -f "./Dockerfile.png" ]; then
echo "Found generated file at: ./Dockerfile.png"
mv "./Dockerfile.png" "$TARGET_GRAPH_FILE"
else
# Try to find it elsewhere
DOCKERFILE_PNG=$(find . -name "Dockerfile.png" -type f | head -1)
if [ -n "$DOCKERFILE_PNG" ]; then
echo "Found generated file at: $DOCKERFILE_PNG"
mv "$DOCKERFILE_PNG" "$TARGET_GRAPH_FILE"
else
echo "Error: Could not find the generated PNG file"
find . -name "*.png" -type f -mmin -5
exit 1
fi
fi
# Check if the graph has changed
NEW_HASH=$(sha256sum "$TARGET_GRAPH_FILE")
if [ "$NEW_HASH" != "$OLD_HASH" ]; then
echo "Graph has changed. Please stage the updated file: $TARGET_GRAPH_FILE"
exit 1
else
echo "No changes in graph detected."
fi
fi
exit 0

View File

@@ -0,0 +1,171 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Ensures all fields in a config dataclass have default values
and that each field has a docstring.
"""
import ast
import inspect
import sys
from itertools import pairwise
import regex as re
def get_attr_docs(cls_node: ast.ClassDef) -> dict[str, str]:
"""
Get any docstrings placed after attribute assignments in a class body.
Adapted from https://davidism.com/attribute-docstrings/
https://davidism.com/mit-license/
"""
out = {}
# Consider each pair of nodes.
for a, b in pairwise(cls_node.body):
# Must be an assignment then a constant string.
if (
not isinstance(a, (ast.Assign, ast.AnnAssign))
or not isinstance(b, ast.Expr)
or not isinstance(b.value, ast.Constant)
or not isinstance(b.value.value, str)
):
continue
doc = inspect.cleandoc(b.value.value)
# An assignment can have multiple targets (a = b = v), but an
# annotated assignment only has one target.
targets = a.targets if isinstance(a, ast.Assign) else [a.target]
for target in targets:
# Must be assigning to a plain name.
if not isinstance(target, ast.Name):
continue
out[target.id] = doc
return out
class ConfigValidator(ast.NodeVisitor):
def __init__(self): ...
def visit_ClassDef(self, node):
# Validate class with both @config and @dataclass decorators
decorators = [
id
for d in node.decorator_list
if (
isinstance(d, ast.Name)
and ((id := d.id) == "config" or id == "dataclass")
)
or (
isinstance(d, ast.Call)
and (isinstance(d.func, ast.Name) and (id := d.func.id) == "dataclass")
)
]
if set(decorators) == {"config", "dataclass"}:
validate_class(node)
elif set(decorators) == {"config"}:
fail(f"Class {node.name} with config decorator must be a dataclass.", node)
self.generic_visit(node)
def validate_class(class_node: ast.ClassDef):
attr_docs = get_attr_docs(class_node)
for stmt in class_node.body:
# A field is defined as a class variable that has a type annotation.
if isinstance(stmt, ast.AnnAssign):
# Skip ClassVar and InitVar
# see https://docs.python.org/3/library/dataclasses.html#class-variables
# and https://docs.python.org/3/library/dataclasses.html#init-only-variables
if (
isinstance(stmt.annotation, ast.Subscript)
and isinstance(stmt.annotation.value, ast.Name)
and stmt.annotation.value.id in {"ClassVar", "InitVar"}
):
continue
if isinstance(stmt.target, ast.Name):
field_name = stmt.target.id
if stmt.value is None:
fail(
f"Field '{field_name}' in {class_node.name} must have "
"a default value.",
stmt,
)
if field_name not in attr_docs:
fail(
f"Field '{field_name}' in {class_node.name} must have "
"a docstring.",
stmt,
)
if (
isinstance(stmt.annotation, ast.Subscript)
and isinstance(stmt.annotation.value, ast.Name)
and stmt.annotation.value.id == "Union"
and isinstance(stmt.annotation.slice, ast.Tuple)
):
args = stmt.annotation.slice.elts
literal_args = [
arg
for arg in args
if isinstance(arg, ast.Subscript)
and isinstance(arg.value, ast.Name)
and arg.value.id == "Literal"
]
if len(literal_args) > 1:
fail(
f"Field '{field_name}' in {class_node.name} must "
"use a single "
"Literal type. Please use 'Literal[Literal1, "
"Literal2]' instead of 'Union[Literal1, Literal2]'"
".",
stmt,
)
def validate_ast(tree: ast.stmt):
ConfigValidator().visit(tree)
def validate_file(file_path: str):
try:
print(f"Validating {file_path} config dataclasses ", end="")
with open(file_path, encoding="utf-8") as f:
source = f.read()
tree = ast.parse(source, filename=file_path)
validate_ast(tree)
except ValueError as e:
print(e)
raise SystemExit(1) from e
else:
print("")
def fail(message: str, node: ast.stmt):
raise ValueError(f"❌ line({node.lineno}): {message}")
def main():
for filename in sys.argv[1:]:
# Only run for Python files in vllm/ or tests/
if not re.match(r"^(vllm|tests)/.*\.py$", filename):
continue
# Only run if the file contains @config
with open(filename, encoding="utf-8") as f:
if "@config" in f.read():
validate_file(filename)
if __name__ == "__main__":
main()