Sync from v0.13
This commit is contained in:
111
tools/pre_commit/check_init_lazy_imports.py
Normal file
111
tools/pre_commit/check_init_lazy_imports.py
Normal file
@@ -0,0 +1,111 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Ensure we perform lazy loading in vllm/__init__.py.
|
||||
i.e: appears only within the `if typing.TYPE_CHECKING:` guard,
|
||||
**except** for a short whitelist.
|
||||
"""
|
||||
|
||||
import ast
|
||||
import sys
|
||||
from collections.abc import Iterable
|
||||
from pathlib import Path
|
||||
from typing import Final
|
||||
|
||||
INIT_PATH: Final = Path("vllm/__init__.py")
|
||||
|
||||
# If you need to add items to whitelist, do it here.
|
||||
ALLOWED_IMPORTS: Final[frozenset[str]] = frozenset(
|
||||
{
|
||||
"vllm.env_override",
|
||||
}
|
||||
)
|
||||
ALLOWED_FROM_MODULES: Final[frozenset[str]] = frozenset(
|
||||
{
|
||||
".version",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _is_internal(name: str | None, *, level: int = 0) -> bool:
|
||||
if level > 0:
|
||||
return True
|
||||
if name is None:
|
||||
return False
|
||||
return name.startswith("vllm.") or name == "vllm"
|
||||
|
||||
|
||||
def _fail(violations: Iterable[tuple[int, str]]) -> None:
|
||||
print("ERROR: Disallowed eager imports in vllm/__init__.py:\n", file=sys.stderr)
|
||||
for lineno, msg in violations:
|
||||
print(f" Line {lineno}: {msg}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def main() -> None:
    """Parse vllm/__init__.py and flag eager imports of internal modules.

    Walks the module AST and records every ``import``/``from ... import``
    of a vllm-internal module that is not inside a ``typing.TYPE_CHECKING``
    guard and not on one of the whitelists (ALLOWED_IMPORTS /
    ALLOWED_FROM_MODULES). Exits with status 1 via ``_fail`` when any
    violations are found.
    """
    source = INIT_PATH.read_text(encoding="utf-8")
    tree = ast.parse(source, filename=str(INIT_PATH))

    # Collected as (line number, message) pairs; closed over by Visitor.
    violations: list[tuple[int, str]] = []

    class Visitor(ast.NodeVisitor):
        def __init__(self) -> None:
            super().__init__()
            # True while we are visiting the body of a TYPE_CHECKING guard.
            self._in_type_checking = False

        def visit_If(self, node: ast.If) -> None:
            # Recognize both spellings of the guard:
            # `if typing.TYPE_CHECKING:` and the bare `if TYPE_CHECKING:`.
            guard_is_type_checking = False
            test = node.test
            if isinstance(test, ast.Attribute) and isinstance(test.value, ast.Name):
                guard_is_type_checking = (
                    test.value.id == "typing" and test.attr == "TYPE_CHECKING"
                )
            elif isinstance(test, ast.Name):
                guard_is_type_checking = test.id == "TYPE_CHECKING"

            if guard_is_type_checking:
                # Imports inside the guard body are lazy and allowed. The
                # `else` branch still executes at runtime, so it is visited
                # with the previous (typically False) flag restored.
                prev = self._in_type_checking
                self._in_type_checking = True
                for child in node.body:
                    self.visit(child)
                self._in_type_checking = prev
                for child in node.orelse:
                    self.visit(child)
            else:
                self.generic_visit(node)

        def visit_Import(self, node: ast.Import) -> None:
            if self._in_type_checking:
                return
            for alias in node.names:
                module_name = alias.name
                if _is_internal(module_name) and module_name not in ALLOWED_IMPORTS:
                    violations.append(
                        (
                            node.lineno,
                            f"import '{module_name}' must be inside typing.TYPE_CHECKING",  # noqa: E501
                        )
                    )

        def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
            if self._in_type_checking:
                return
            # Reconstruct the module exactly as written, including the
            # leading dots of a relative import, to match the whitelist.
            module_as_written = ("." * node.level) + (node.module or "")
            if (
                _is_internal(node.module, level=node.level)
                and module_as_written not in ALLOWED_FROM_MODULES
            ):
                violations.append(
                    (
                        node.lineno,
                        f"from '{module_as_written}' import ... must be inside typing.TYPE_CHECKING",  # noqa: E501
                    )
                )

    Visitor().visit(tree)

    if violations:
        _fail(violations)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
108
tools/pre_commit/check_pickle_imports.py
Normal file
108
tools/pre_commit/check_pickle_imports.py
Normal file
@@ -0,0 +1,108 @@
|
||||
#!/usr/bin/env python3
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import sys
|
||||
|
||||
import regex as re
|
||||
|
||||
# List of files (relative to repo root) that are allowed to import pickle or
|
||||
# cloudpickle
|
||||
#
|
||||
# STOP AND READ BEFORE YOU ADD ANYTHING ELSE TO THIS LIST:
|
||||
# The pickle and cloudpickle modules are known to be unsafe when deserializing
|
||||
# data from potentially untrusted parties. They have resulted in multiple CVEs
|
||||
# for vLLM and numerous vulnerabilities in the Python ecosystem more broadly.
|
||||
# Before adding new uses of pickle/cloudpickle, please consider safer
|
||||
# alternatives like msgpack or pydantic that are already in use in vLLM. Only
|
||||
# add to this list if absolutely necessary and after careful security review.
|
||||
ALLOWED_FILES = {
|
||||
# pickle
|
||||
"vllm/multimodal/hasher.py",
|
||||
"vllm/transformers_utils/config.py",
|
||||
"vllm/model_executor/models/registry.py",
|
||||
"vllm/compilation/caching.py",
|
||||
"vllm/distributed/utils.py",
|
||||
"vllm/distributed/parallel_state.py",
|
||||
"vllm/distributed/device_communicators/all_reduce_utils.py",
|
||||
"vllm/distributed/device_communicators/shm_broadcast.py",
|
||||
"vllm/distributed/device_communicators/shm_object_storage.py",
|
||||
"vllm/utils/hashing.py",
|
||||
"tests/tokenizers_/test_hf.py",
|
||||
"tests/utils_/test_hashing.py",
|
||||
"benchmarks/kernels/graph_machete_bench.py",
|
||||
"benchmarks/kernels/benchmark_lora.py",
|
||||
"benchmarks/kernels/benchmark_machete.py",
|
||||
"benchmarks/fused_kernels/layernorm_rms_benchmarks.py",
|
||||
"benchmarks/cutlass_benchmarks/w8a8_benchmarks.py",
|
||||
"benchmarks/cutlass_benchmarks/sparse_benchmarks.py",
|
||||
# cloudpickle
|
||||
"vllm/v1/executor/multiproc_executor.py",
|
||||
"vllm/v1/executor/ray_executor.py",
|
||||
"vllm/entrypoints/llm.py",
|
||||
"tests/utils.py",
|
||||
# pickle and cloudpickle
|
||||
"vllm/v1/serial_utils.py",
|
||||
}
|
||||
|
||||
PICKLE_RE = re.compile(
|
||||
r"^\s*(import\s+(pickle|cloudpickle)(\s|$|\sas)"
|
||||
r"|from\s+(pickle|cloudpickle)\s+import\b)"
|
||||
)
|
||||
|
||||
|
||||
def scan_file(path: str) -> int:
    """Scan a file for direct pickle/cloudpickle imports.

    Prints an error for EVERY offending line (the previous version
    returned after the first match, hiding later violations in the same
    file) and returns 1 if any were found, 0 otherwise.
    """
    found = 0
    with open(path, encoding="utf-8") as f:
        for i, line in enumerate(f, 1):
            if PICKLE_RE.match(line):
                print(
                    f"{path}:{i}: "
                    "\033[91merror:\033[0m "  # red color
                    "Found pickle/cloudpickle import"
                )
                found = 1
    return found
|
||||
|
||||
|
||||
def main():
    """Scan each CLI-given file not on the whitelist; return 1 on any hit."""
    rc = 0
    for path in sys.argv[1:]:
        if path not in ALLOWED_FILES:
            rc |= scan_file(path)
    return rc
|
||||
|
||||
|
||||
def test_regex():
    """Exercise PICKLE_RE against known positive and negative lines."""
    test_cases = [
        # Should match
        ("import pickle", True),
        ("import cloudpickle", True),
        ("import pickle as pkl", True),
        ("import cloudpickle as cpkl", True),
        ("from pickle import *", True),
        ("from cloudpickle import dumps", True),
        ("from pickle import dumps, loads", True),
        ("from cloudpickle import (dumps, loads)", True),
        (" import pickle", True),
        ("\timport cloudpickle", True),
        ("from pickle import loads", True),
        # Should not match
        ("import somethingelse", False),
        ("from somethingelse import pickle", False),
        ("# import pickle", False),
        ("print('import pickle')", False),
        ("import pickleas as asdf", False),
    ]
    for i, (line, should_match) in enumerate(test_cases):
        result = bool(PICKLE_RE.match(line))
        assert result == should_match, (
            f"Test case {i} failed: '{line}' (expected {should_match}, got {result})"
        )
    print("All regex tests passed.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if "--test-regex" in sys.argv:
|
||||
test_regex()
|
||||
else:
|
||||
sys.exit(main())
|
||||
151
tools/pre_commit/check_spdx_header.py
Normal file
151
tools/pre_commit/check_spdx_header.py
Normal file
@@ -0,0 +1,151 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import sys
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class SPDXStatus(Enum):
    """SPDX header status enumeration"""

    EMPTY = "empty"  # empty __init__.py
    COMPLETE = "complete"  # both license and copyright lines present
    MISSING_LICENSE = "missing_license"  # Only has copyright line
    MISSING_COPYRIGHT = "missing_copyright"  # Only has license line
    MISSING_BOTH = "missing_both"  # Completely missing
|
||||
|
||||
|
||||
FULL_SPDX_HEADER = (
|
||||
"# SPDX-License-Identifier: Apache-2.0\n"
|
||||
"# SPDX-FileCopyrightText: Copyright contributors to the vLLM project"
|
||||
)
|
||||
|
||||
LICENSE_LINE = "# SPDX-License-Identifier: Apache-2.0"
|
||||
COPYRIGHT_LINE = "# SPDX-FileCopyrightText: Copyright contributors to the vLLM project" # noqa: E501
|
||||
|
||||
|
||||
def check_spdx_header_status(file_path):
    """Classify *file_path* by which SPDX header lines it already contains."""
    with open(file_path, encoding="UTF-8") as file:
        lines = file.readlines()

    if not lines:
        # An empty file needs no header.
        return SPDXStatus.EMPTY

    # A shebang, if present, is allowed to precede the header.
    start_idx = 1 if lines[0].startswith("#!") else 0

    # Look at every line (not just the first two) so headers placed
    # after other comments are still detected.
    stripped = {line.strip() for line in lines[start_idx:]}
    has_license = LICENSE_LINE in stripped
    has_copyright = COPYRIGHT_LINE in stripped

    if has_license and has_copyright:
        return SPDXStatus.COMPLETE
    if has_license:
        return SPDXStatus.MISSING_COPYRIGHT
    if has_copyright:
        return SPDXStatus.MISSING_LICENSE
    return SPDXStatus.MISSING_BOTH
|
||||
|
||||
|
||||
def add_header(file_path, status):
    """Add or supplement SPDX header based on status.

    Rewrites *file_path* in place, inserting whichever SPDX line(s) are
    missing while preserving a leading shebang and all other content.
    """
    with open(file_path, "r+", encoding="UTF-8") as file:
        lines = file.readlines()
        # Rewind and clear the file; everything is re-emitted below.
        file.seek(0, 0)
        file.truncate()

        if status == SPDXStatus.MISSING_BOTH:
            # Completely missing, add complete header
            if lines and lines[0].startswith("#!"):
                # Preserve shebang line
                file.write(lines[0])
                file.write(FULL_SPDX_HEADER + "\n")
                file.writelines(lines[1:])
            else:
                # Add header directly
                file.write(FULL_SPDX_HEADER + "\n")
                file.writelines(lines)

        elif status == SPDXStatus.MISSING_COPYRIGHT:
            # Only has license line, need to add copyright line
            # Find the license line and add copyright line after it
            for i, line in enumerate(lines):
                if line.strip() == LICENSE_LINE:
                    # Insert copyright line after license line
                    lines.insert(
                        i + 1,
                        f"{COPYRIGHT_LINE}\n",
                    )
                    break

            file.writelines(lines)

        elif status == SPDXStatus.MISSING_LICENSE:
            # Only has copyright line, need to add license line
            # Find the copyright line and add license line before it
            for i, line in enumerate(lines):
                if line.strip() == COPYRIGHT_LINE:
                    # Insert license line before copyright line
                    lines.insert(i, f"{LICENSE_LINE}\n")
                    break
            file.writelines(lines)
|
||||
|
||||
|
||||
def main():
    """Check every file on argv, fixing any missing SPDX header lines.

    Exits 1 (pre-commit convention: "files were modified") when at least
    one file needed fixing, else 0. Replaces three near-identical
    classify/print/fix loops with a single status→files mapping and
    drops the dead `else: continue` branch of the original.
    """
    # Insertion order fixes the report order: missing both, then missing
    # copyright, then missing license (same as before).
    buckets: dict[SPDXStatus, list[str]] = {
        SPDXStatus.MISSING_BOTH: [],
        SPDXStatus.MISSING_COPYRIGHT: [],
        SPDXStatus.MISSING_LICENSE: [],
    }

    for file_path in sys.argv[1:]:
        status = check_spdx_header_status(file_path)
        if status in buckets:
            buckets[status].append(file_path)

    needs_fix = any(buckets.values())
    if needs_fix:
        print("The following files are missing the SPDX header:")
        for status, files in buckets.items():
            for file_path in files:
                print(f" {file_path}")
                add_header(file_path, status)

    sys.exit(1 if needs_fix else 0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
88
tools/pre_commit/check_triton_import.py
Normal file
88
tools/pre_commit/check_triton_import.py
Normal file
@@ -0,0 +1,88 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
import regex as re
|
||||
|
||||
FORBIDDEN_IMPORT_RE = re.compile(r"^(from|import)\s+triton(\s|\.|$)")

# The only accepted ways to import triton.
ALLOWED_LINES = {
    "from vllm.triton_utils import triton",
    "from vllm.triton_utils import tl",
    "from vllm.triton_utils import tl, triton",
}

# Files exempt from the rule (they implement the indirection itself).
ALLOWED_FILES = {"vllm/triton_utils/importing.py"}


def is_allowed_file(current_file: str) -> bool:
    """Return True for files exempt from the triton-import rule."""
    return current_file in ALLOWED_FILES


def is_forbidden_import(line: str) -> bool:
    """Return True if *line* imports triton directly instead of via vllm."""
    candidate = line.strip()
    if candidate in ALLOWED_LINES:
        return False
    return bool(FORBIDDEN_IMPORT_RE.match(candidate))
|
||||
|
||||
|
||||
def parse_diff(diff: str) -> list[str]:
    """Extract forbidden triton imports from a unified git diff.

    Returns "file:lineno: code" strings for every ADDED line that imports
    triton directly. Only "+" lines are inspected; whitelisted files are
    skipped wholesale.
    """
    violations = []
    current_file = None
    # Line number on the "+" side of the current hunk; assumes a "@@"
    # hunk header always precedes the first "+" line (true for git output).
    current_lineno = None
    skip_allowed_file = False

    for line in diff.splitlines():
        if line.startswith("+++ b/"):
            # Start of a new file's section: record its path and whether
            # the whole file is whitelisted.
            current_file = line[6:]
            skip_allowed_file = is_allowed_file(current_file)
        elif skip_allowed_file:
            continue
        elif line.startswith("@@"):
            # Hunk header: pick up the starting line number of the "+" side.
            match = re.search(r"\+(\d+)", line)
            if match:
                current_lineno = int(match.group(1)) - 1  # next "+ line" is here
        elif line.startswith("+") and not line.startswith("++"):
            current_lineno += 1
            code_line = line[1:]
            if is_forbidden_import(code_line):
                violations.append(
                    f"{current_file}:{current_lineno}: {code_line.strip()}"
                )
    return violations
|
||||
|
||||
|
||||
def get_diff(diff_type: str) -> str:
    """Return the unified (zero-context) git diff for "staged" or "unstaged".

    Raises:
        ValueError: for any other *diff_type*.
    """
    if diff_type == "staged":
        cmd = ["git", "diff", "--cached", "--unified=0"]
    elif diff_type == "unstaged":
        cmd = ["git", "diff", "--unified=0"]
    else:
        raise ValueError(f"Unknown diff_type: {diff_type}")
    return subprocess.check_output(cmd, text=True)
|
||||
|
||||
|
||||
def main():
    """Scan staged and unstaged diffs for direct triton imports.

    Returns 1 when any violation is found, 0 otherwise. A failing git
    invocation is reported but does not abort the other scan.
    """
    all_violations = []
    for diff_type in ("staged", "unstaged"):
        try:
            all_violations.extend(parse_diff(get_diff(diff_type)))
        except subprocess.CalledProcessError as e:
            print(f"[{diff_type}] Git diff failed: {e}", file=sys.stderr)

    if not all_violations:
        return 0

    print(
        "❌ Forbidden direct `import triton` detected."
        " ➤ Use `from vllm.triton_utils import triton` instead.\n"
    )
    for v in all_violations:
        print(f"❌ {v}")
    return 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
83
tools/pre_commit/enforce_regex_import.py
Normal file
83
tools/pre_commit/enforce_regex_import.py
Normal file
@@ -0,0 +1,83 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
import regex as re
|
||||
|
||||
# Direct uses of the stdlib `re` module (vllm standardizes on `regex`).
FORBIDDEN_PATTERNS = re.compile(r"^\s*(?:import\s+re(?:$|\s|,)|from\s+re\s+import)")
# The sanctioned spellings.
ALLOWED_PATTERNS = [
    re.compile(r"^\s*import\s+regex\s+as\s+re\s*$"),
    re.compile(r"^\s*import\s+regex\s*$"),
]


def get_staged_python_files() -> list[str]:
    """Return staged (added/modified) Python files, or [] if git fails."""
    try:
        proc = subprocess.run(
            ["git", "diff", "--cached", "--name-only", "--diff-filter=AM"],
            capture_output=True,
            text=True,
            check=True,
        )
    except subprocess.CalledProcessError:
        return []
    listing = proc.stdout.strip()
    staged = listing.split("\n") if listing else []
    return [name for name in staged if name.endswith(".py")]


def is_forbidden_import(line: str) -> bool:
    """Return True if *line* imports the stdlib `re` module directly."""
    candidate = line.strip()
    if not FORBIDDEN_PATTERNS.match(candidate):
        return False
    return not any(pattern.match(candidate) for pattern in ALLOWED_PATTERNS)
|
||||
|
||||
|
||||
def check_file(filepath: str) -> list[tuple[int, str]]:
    """Collect (line number, stripped line) pairs of forbidden imports."""
    found: list[tuple[int, str]] = []
    try:
        with open(filepath, encoding="utf-8") as handle:
            for lineno, text in enumerate(handle, 1):
                if is_forbidden_import(text):
                    found.append((lineno, text.strip()))
    except (OSError, UnicodeDecodeError):
        # Unreadable or non-UTF-8 files are simply skipped.
        pass
    return found
|
||||
|
||||
|
||||
def main() -> int:
    """Check staged Python files for direct `re` imports.

    Returns 1 (pre-commit failure) if any violation was found. Fixes the
    reported count: the original added 1 per offending FILE while the
    message claims to count violation lines; now every offending line
    is counted.
    """
    files = get_staged_python_files()
    if not files:
        return 0

    total_violations = 0

    for filepath in files:
        # A file can be staged yet already deleted from the worktree.
        if not Path(filepath).exists():
            continue

        # setup.py is exempt from the regex-import rule.
        if filepath == "setup.py":
            continue

        violations = check_file(filepath)
        if violations:
            print(f"\n❌ {filepath}:")
            for line_num, line in violations:
                print(f" Line {line_num}: {line}")
            # Count every offending line, not one per file.
            total_violations += len(violations)

    if total_violations > 0:
        print(f"\n💡 Found {total_violations} violation(s).")
        print("❌ Please replace 'import re' with 'import regex as re'")
        print(" Also replace 'from re import ...' with 'from regex import ...'")  # noqa: E501
        print("✅ Allowed imports:")
        print(" - import regex as re")
        print(" - import regex")  # noqa: E501
        return 1

    return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
34
tools/pre_commit/generate_nightly_torch_test.py
Normal file
34
tools/pre_commit/generate_nightly_torch_test.py
Normal file
@@ -0,0 +1,34 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Generates specialized requirements files for nightly PyTorch testing.
|
||||
|
||||
This script reads the main test requirements input file (`requirements/test.in`)
|
||||
and splits its content into two files:
|
||||
1. `requirements/nightly_torch_test.txt`: Contains dependencies
|
||||
except PyTorch-related.
|
||||
2. `torch_nightly_test.txt`: Contains only PyTorch-related packages.
|
||||
"""
|
||||
|
||||
input_file = "requirements/test.in"
|
||||
output_file = "requirements/nightly_torch_test.txt"
|
||||
|
||||
# white list of packages that are not compatible with PyTorch nightly directly
|
||||
# with pip install. Please add your package to this list if it is not compatible
|
||||
# or make the dependency test fails.
|
||||
white_list = ["torch", "torchaudio", "torchvision", "mamba_ssm"]
|
||||
|
||||
with open(input_file) as f:
|
||||
lines = f.readlines()
|
||||
|
||||
skip_next = False
|
||||
|
||||
for line in lines:
|
||||
if skip_next:
|
||||
if line.startswith((" ", "\t")) or line.strip() == "":
|
||||
continue
|
||||
skip_next = False
|
||||
|
||||
if any(k in line.lower() for k in white_list):
|
||||
skip_next = True
|
||||
continue
|
||||
158
tools/pre_commit/mypy.py
Executable file
158
tools/pre_commit/mypy.py
Executable file
@@ -0,0 +1,158 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Run mypy on changed files.
|
||||
|
||||
This script is designed to be used as a pre-commit hook. It runs mypy
|
||||
on files that have been changed. It groups files into different mypy calls
|
||||
based on their directory to avoid import following issues.
|
||||
|
||||
Usage:
|
||||
python tools/pre_commit/mypy.py <ci> <python_version> <changed_files...>
|
||||
|
||||
Args:
|
||||
ci: "1" if running in CI, "0" otherwise. In CI, follow_imports is set to
|
||||
"silent" for the main group of files.
|
||||
python_version: Python version to use (e.g., "3.10") or "local" to use
|
||||
the local Python version.
|
||||
changed_files: List of changed files to check.
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
import regex as re
|
||||
|
||||
# Directories/globs checked in the main mypy invocation.
FILES = [
    "vllm/*.py",
    "vllm/assets",
    "vllm/distributed",
    "vllm/engine",
    "vllm/entrypoints",
    "vllm/executor",
    "vllm/inputs",
    "vllm/logging_utils",
    "vllm/multimodal",
    "vllm/platforms",
    "vllm/plugins",
    "vllm/tokenizers",
    "vllm/transformers_utils",
    "vllm/triton_utils",
    "vllm/usage",
    "vllm/utils",
    "vllm/worker",
    "vllm/v1/core",
    "vllm/v1/engine",
    "vllm/v1/executor",
    "vllm/v1/metrics",
    "vllm/v1/pool",
    "vllm/v1/sample",
    "vllm/v1/worker",
]

# After fixing errors resulting from changing follow_imports
# from "skip" to "silent", move the following directories to FILES
SEPARATE_GROUPS = [
    "tests",
    # v0 related
    "vllm/attention",
    "vllm/compilation",
    "vllm/lora",
    "vllm/model_executor",
    # v1 related
    "vllm/v1/attention",
    "vllm/v1/kv_offload",
    "vllm/v1/spec_decode",
    "vllm/v1/structured_output",
]

# TODO(woosuk): Include the code from Megatron and HuggingFace.
EXCLUDE = [
    "vllm/engine/arg_utils.py",
    "vllm/model_executor/parallel_utils",
    "vllm/model_executor/models",
    "vllm/model_executor/layers/fla/ops",
    # Ignore triton kernels in ops.
    "vllm/attention/ops",
]


def group_files(changed_files: list[str]) -> dict[str, list[str]]:
    """
    Group changed files into different mypy calls.

    Args:
        changed_files: List of changed files.

    Returns:
        A dictionary mapping file group names ("" for the main group) to
        the changed files belonging to that group. Excluded files and
        files matching no group are dropped.
    """
    excluded = re.compile(f"^{'|'.join(EXCLUDE)}.*")
    main_group = re.compile(f"^({'|'.join(FILES)}).*")
    groups: dict[str, list[str]] = {"": []}
    groups.update({name: [] for name in SEPARATE_GROUPS})
    for path in changed_files:
        # Skip files which should be ignored completely.
        if excluded.match(path):
            continue
        # Main group first; otherwise the first matching separate group.
        if main_group.match(path):
            groups[""].append(path)
        else:
            for directory in SEPARATE_GROUPS:
                if re.match(f"^{directory}.*", path):
                    groups[directory].append(path)
                    break
    return groups
|
||||
|
||||
|
||||
def mypy(
    targets: list[str],
    python_version: str | None,
    follow_imports: str | None,
    file_group: str,
) -> int:
    """
    Run mypy on the given targets.

    Args:
        targets: Files or directories to check.
        python_version: Value for ``--python-version``, or None to use
            mypy's default.
        follow_imports: Value for ``--follow-imports``, or None to use
            mypy's default behavior.
        file_group: Group name, echoed only for log readability.

    Returns:
        mypy's process return code.
    """
    cmd = ["mypy"]
    if python_version is not None:
        cmd.extend(["--python-version", python_version])
    if follow_imports is not None:
        cmd.extend(["--follow-imports", follow_imports])
    # Echo the command (without the long target list) for CI logs.
    print(f"$ {' '.join(cmd)} {file_group}")
    return subprocess.run(cmd + targets, check=False).returncode
|
||||
|
||||
|
||||
def main():
    """Entry point: dispatch mypy over the grouped changed files.

    argv: <ci flag "1"/"0"> <python version or "local"> <changed files...>
    """
    ci = sys.argv[1] == "1"
    python_version = sys.argv[2]
    file_groups = group_files(sys.argv[3:])

    # "local" means: match the interpreter running this script.
    if python_version == "local":
        python_version = f"{sys.version_info.major}.{sys.version_info.minor}"

    returncode = 0
    for file_group, changed_files in file_groups.items():
        if not changed_files:
            continue
        # In CI the main group ("") uses mypy's configured default mode;
        # all other groups skip imports to avoid cascading errors.
        follow = None if ci and file_group == "" else "skip"
        returncode |= mypy(changed_files, python_version, follow, file_group)
    return returncode
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
15
tools/pre_commit/png-lint.sh
Executable file
15
tools/pre_commit/png-lint.sh
Executable file
@@ -0,0 +1,15 @@
|
||||
#!/bin/bash

# Ensure that *.excalidraw.png files have the excalidraw metadata
# embedded in them. This ensures they can be loaded back into
# the tool and edited in the future.

find . -iname '*.excalidraw.png' | while read -r file; do
	# Files ignored by git (e.g. local build artifacts) are not checked.
	if git check-ignore -q "$file"; then
		continue
	fi
	# PNGs exported with "Embed Scene" contain the excalidraw+json marker.
	if ! grep -q "excalidraw+json" "$file"; then
		echo "$file was not exported from excalidraw with 'Embed Scene' enabled."
		# NOTE: this exits the piped while-subshell; its status becomes the
		# pipeline's (and hence the script's) exit status.
		exit 1
	fi
done
|
||||
22
tools/pre_commit/shellcheck.sh
Executable file
22
tools/pre_commit/shellcheck.sh
Executable file
@@ -0,0 +1,22 @@
|
||||
#!/bin/bash
set -e

# shellcheck release used for the local auto-install below.
scversion="stable"

# Reuse a previously downloaded shellcheck, if present.
if [ -d "shellcheck-${scversion}" ]; then
	export PATH="$PATH:$(pwd)/shellcheck-${scversion}"
fi

if ! [ -x "$(command -v shellcheck)" ]; then
	if [ "$(uname -s)" != "Linux" ] || [ "$(uname -m)" != "x86_64" ]; then
		echo "Please install shellcheck: https://github.com/koalaman/shellcheck?tab=readme-ov-file#installing"
		exit 1
	fi

	# automatic local install if linux x86_64
	wget -qO- "https://github.com/koalaman/shellcheck/releases/download/${scversion?}/shellcheck-${scversion?}.linux.x86_64.tar.xz" | tar -xJv
	export PATH="$PATH:$(pwd)/shellcheck-${scversion}"
fi

# TODO - fix warnings in .buildkite/scripts/hardware_ci/run-amd-test.sh
# NOTE(review): ".git" appears AFTER the -name expression below; GNU find
# requires all paths to precede expressions, so this prune likely does not
# work as intended — verify against the target find implementation.
find . -name "*.sh" ".git" -prune -not -path "./.buildkite/scripts/hardware_ci/run-amd-test.sh" -print0 | xargs -0 -I {} sh -c 'git check-ignore -q "{}" || shellcheck -s bash "{}"'
||||
81
tools/pre_commit/update-dockerfile-graph.sh
Executable file
81
tools/pre_commit/update-dockerfile-graph.sh
Executable file
@@ -0,0 +1,81 @@
|
||||
#!/bin/bash
# Update Dockerfile dependency graph when docker/Dockerfile changes.
# This script is designed to be used as a pre-commit hook.

set -euo pipefail

# Accept file paths as arguments
FILES=("$@")

# Check if docker/Dockerfile is among the provided files
if printf '%s\n' "${FILES[@]}" | grep -q "^docker/Dockerfile$"; then
  echo "docker/Dockerfile has changed, attempting to update dependency graph..."

  # Check if Docker is installed and running; exit 0 (non-blocking) if not,
  # so contributors without Docker can still commit.
  if ! command -v docker &> /dev/null; then
    echo "Warning: Docker command not found. Skipping Dockerfile graph update."
    echo "Please install Docker to automatically update the graph: https://docs.docker.com/get-docker/"
    exit 0
  fi
  if ! docker info &> /dev/null; then
    echo "Warning: Docker daemon is not running. Skipping Dockerfile graph update."
    echo "Please start Docker to automatically update the graph."
    exit 0
  fi

  # Define the target file path
  TARGET_GRAPH_FILE="docs/assets/contributing/dockerfile-stages-dependency.png"

  # Ensure target directory exists
  mkdir -p "$(dirname "$TARGET_GRAPH_FILE")"

  # Store old image hash in a variable if the file exists.
  # Both hashes include the same file path, so direct string comparison
  # of the `sha256sum` output lines is valid.
  OLD_HASH=""
  if [ -f "$TARGET_GRAPH_FILE" ]; then
    OLD_HASH=$(sha256sum "$TARGET_GRAPH_FILE")
  fi

  # Generate Dockerfile graph
  echo "Running dockerfilegraph tool..."
  docker run \
    --rm \
    --user "$(id -u):$(id -g)" \
    --workdir /workspace \
    --volume "$(pwd)":/workspace \
    ghcr.io/patrickhoefler/dockerfilegraph:alpine \
    --output png \
    --dpi 200 \
    --max-label-length 50 \
    --filename docker/Dockerfile \
    --legend

  echo "Finding generated PNG file..."
  # Check for Dockerfile.png in the root directory (most likely location)
  if [ -f "./Dockerfile.png" ]; then
    echo "Found generated file at: ./Dockerfile.png"
    mv "./Dockerfile.png" "$TARGET_GRAPH_FILE"
  else
    # Try to find it elsewhere
    DOCKERFILE_PNG=$(find . -name "Dockerfile.png" -type f | head -1)

    if [ -n "$DOCKERFILE_PNG" ]; then
      echo "Found generated file at: $DOCKERFILE_PNG"
      mv "$DOCKERFILE_PNG" "$TARGET_GRAPH_FILE"
    else
      echo "Error: Could not find the generated PNG file"
      # Show recently created PNGs to help diagnose where it went.
      find . -name "*.png" -type f -mmin -5
      exit 1
    fi
  fi

  # Check if the graph has changed; failing (exit 1) makes pre-commit
  # stop so the regenerated file can be staged.
  NEW_HASH=$(sha256sum "$TARGET_GRAPH_FILE")
  if [ "$NEW_HASH" != "$OLD_HASH" ]; then
    echo "Graph has changed. Please stage the updated file: $TARGET_GRAPH_FILE"
    exit 1
  else
    echo "No changes in graph detected."
  fi
fi

exit 0
|
||||
171
tools/pre_commit/validate_config.py
Normal file
171
tools/pre_commit/validate_config.py
Normal file
@@ -0,0 +1,171 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Ensures all fields in a config dataclass have default values
|
||||
and that each field has a docstring.
|
||||
"""
|
||||
|
||||
import ast
|
||||
import inspect
|
||||
import sys
|
||||
from itertools import pairwise
|
||||
|
||||
import regex as re
|
||||
|
||||
|
||||
def get_attr_docs(cls_node: ast.ClassDef) -> dict[str, str]:
|
||||
"""
|
||||
Get any docstrings placed after attribute assignments in a class body.
|
||||
|
||||
Adapted from https://davidism.com/attribute-docstrings/
|
||||
https://davidism.com/mit-license/
|
||||
"""
|
||||
|
||||
out = {}
|
||||
|
||||
# Consider each pair of nodes.
|
||||
for a, b in pairwise(cls_node.body):
|
||||
# Must be an assignment then a constant string.
|
||||
if (
|
||||
not isinstance(a, (ast.Assign, ast.AnnAssign))
|
||||
or not isinstance(b, ast.Expr)
|
||||
or not isinstance(b.value, ast.Constant)
|
||||
or not isinstance(b.value.value, str)
|
||||
):
|
||||
continue
|
||||
|
||||
doc = inspect.cleandoc(b.value.value)
|
||||
|
||||
# An assignment can have multiple targets (a = b = v), but an
|
||||
# annotated assignment only has one target.
|
||||
targets = a.targets if isinstance(a, ast.Assign) else [a.target]
|
||||
|
||||
for target in targets:
|
||||
# Must be assigning to a plain name.
|
||||
if not isinstance(target, ast.Name):
|
||||
continue
|
||||
|
||||
out[target.id] = doc
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ConfigValidator(ast.NodeVisitor):
    """AST visitor that validates classes decorated with @config."""

    def __init__(self): ...

    def visit_ClassDef(self, node):
        # Collect @config / @dataclass decorator names, whether applied
        # bare (@dataclass) or called (@dataclass(...)).
        decorators = []
        for d in node.decorator_list:
            if isinstance(d, ast.Name) and d.id in ("config", "dataclass"):
                decorators.append(d.id)
            elif (
                isinstance(d, ast.Call)
                and isinstance(d.func, ast.Name)
                and d.func.id == "dataclass"
            ):
                decorators.append("dataclass")

        found = set(decorators)
        if found == {"config", "dataclass"}:
            validate_class(node)
        elif found == {"config"}:
            fail(f"Class {node.name} with config decorator must be a dataclass.", node)

        self.generic_visit(node)
|
||||
|
||||
|
||||
def validate_class(class_node: ast.ClassDef):
    """Check every annotated field of a config dataclass.

    Each field must have a default value and an attribute docstring, and
    must not be a Union of multiple Literal types. Raises ValueError (via
    ``fail``) on the first offending field.
    """
    attr_docs = get_attr_docs(class_node)

    for stmt in class_node.body:
        # A field is defined as a class variable that has a type annotation.
        if isinstance(stmt, ast.AnnAssign):
            # Skip ClassVar and InitVar
            # see https://docs.python.org/3/library/dataclasses.html#class-variables
            # and https://docs.python.org/3/library/dataclasses.html#init-only-variables
            if (
                isinstance(stmt.annotation, ast.Subscript)
                and isinstance(stmt.annotation.value, ast.Name)
                and stmt.annotation.value.id in {"ClassVar", "InitVar"}
            ):
                continue

            if isinstance(stmt.target, ast.Name):
                field_name = stmt.target.id
                # Every config field needs a default so the config can be
                # constructed without arguments.
                if stmt.value is None:
                    fail(
                        f"Field '{field_name}' in {class_node.name} must have "
                        "a default value.",
                        stmt,
                    )

                if field_name not in attr_docs:
                    fail(
                        f"Field '{field_name}' in {class_node.name} must have "
                        "a docstring.",
                        stmt,
                    )

                # A Union of several Literal[...] members should be written
                # as one combined Literal instead.
                if (
                    isinstance(stmt.annotation, ast.Subscript)
                    and isinstance(stmt.annotation.value, ast.Name)
                    and stmt.annotation.value.id == "Union"
                    and isinstance(stmt.annotation.slice, ast.Tuple)
                ):
                    args = stmt.annotation.slice.elts
                    literal_args = [
                        arg
                        for arg in args
                        if isinstance(arg, ast.Subscript)
                        and isinstance(arg.value, ast.Name)
                        and arg.value.id == "Literal"
                    ]
                    if len(literal_args) > 1:
                        fail(
                            f"Field '{field_name}' in {class_node.name} must "
                            "use a single "
                            "Literal type. Please use 'Literal[Literal1, "
                            "Literal2]' instead of 'Union[Literal1, Literal2]'"
                            ".",
                            stmt,
                        )
|
||||
|
||||
|
||||
def validate_ast(tree: ast.AST):
    """Run the config-dataclass validator over a parsed module tree."""
    ConfigValidator().visit(tree)
|
||||
|
||||
|
||||
def validate_file(file_path: str):
    """Validate one file's config dataclasses, exiting 1 on failure.

    Prints a progress line, then either the validation error followed by
    ``SystemExit(1)``, or a trailing check mark on success.
    """
    try:
        print(f"Validating {file_path} config dataclasses ", end="")
        with open(file_path, encoding="utf-8") as handle:
            contents = handle.read()
        validate_ast(ast.parse(contents, filename=file_path))
    except ValueError as err:
        # Validation failures (and UnicodeDecodeError, a ValueError
        # subclass) are reported and turned into exit code 1.
        print(err)
        raise SystemExit(1) from err
    print("✅")
|
||||
|
||||
|
||||
def fail(message: str, node: ast.stmt):
    """Raise a ValueError pointing at *node*'s source line."""
    location = f"❌ line({node.lineno})"
    raise ValueError(f"{location}: {message}")
|
||||
|
||||
|
||||
def main():
    """Validate every @config-containing Python file passed on argv."""
    for filename in sys.argv[1:]:
        # Only run for Python files in vllm/ or tests/
        if re.match(r"^(vllm|tests)/.*\.py$", filename) is None:
            continue
        # Cheap pre-filter: only parse files that mention @config at all.
        with open(filename, encoding="utf-8") as f:
            contents = f.read()
        if "@config" in contents:
            validate_file(filename)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user