Sync from v0.13
This commit is contained in:
88
tools/pre_commit/check_triton_import.py
Normal file
88
tools/pre_commit/check_triton_import.py
Normal file
@@ -0,0 +1,88 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
import regex as re
|
||||
|
||||
FORBIDDEN_IMPORT_RE = re.compile(r"^(from|import)\s+triton(\s|\.|$)")
|
||||
|
||||
# the way allowed to import triton
|
||||
ALLOWED_LINES = {
|
||||
"from vllm.triton_utils import triton",
|
||||
"from vllm.triton_utils import tl",
|
||||
"from vllm.triton_utils import tl, triton",
|
||||
}
|
||||
|
||||
ALLOWED_FILES = {"vllm/triton_utils/importing.py"}
|
||||
|
||||
|
||||
def is_allowed_file(current_file: str) -> bool:
|
||||
return current_file in ALLOWED_FILES
|
||||
|
||||
|
||||
def is_forbidden_import(line: str) -> bool:
|
||||
stripped = line.strip()
|
||||
return bool(FORBIDDEN_IMPORT_RE.match(stripped)) and stripped not in ALLOWED_LINES
|
||||
|
||||
|
||||
def parse_diff(diff: str) -> list[str]:
|
||||
violations = []
|
||||
current_file = None
|
||||
current_lineno = None
|
||||
skip_allowed_file = False
|
||||
|
||||
for line in diff.splitlines():
|
||||
if line.startswith("+++ b/"):
|
||||
current_file = line[6:]
|
||||
skip_allowed_file = is_allowed_file(current_file)
|
||||
elif skip_allowed_file:
|
||||
continue
|
||||
elif line.startswith("@@"):
|
||||
match = re.search(r"\+(\d+)", line)
|
||||
if match:
|
||||
current_lineno = int(match.group(1)) - 1 # next "+ line" is here
|
||||
elif line.startswith("+") and not line.startswith("++"):
|
||||
current_lineno += 1
|
||||
code_line = line[1:]
|
||||
if is_forbidden_import(code_line):
|
||||
violations.append(
|
||||
f"{current_file}:{current_lineno}: {code_line.strip()}"
|
||||
)
|
||||
return violations
|
||||
|
||||
|
||||
def get_diff(diff_type: str) -> str:
|
||||
if diff_type == "staged":
|
||||
return subprocess.check_output(
|
||||
["git", "diff", "--cached", "--unified=0"], text=True
|
||||
)
|
||||
elif diff_type == "unstaged":
|
||||
return subprocess.check_output(["git", "diff", "--unified=0"], text=True)
|
||||
else:
|
||||
raise ValueError(f"Unknown diff_type: {diff_type}")
|
||||
|
||||
|
||||
def main():
|
||||
all_violations = []
|
||||
for diff_type in ["staged", "unstaged"]:
|
||||
try:
|
||||
diff_output = get_diff(diff_type)
|
||||
violations = parse_diff(diff_output)
|
||||
all_violations.extend(violations)
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"[{diff_type}] Git diff failed: {e}", file=sys.stderr)
|
||||
|
||||
if all_violations:
|
||||
print(
|
||||
"❌ Forbidden direct `import triton` detected."
|
||||
" ➤ Use `from vllm.triton_utils import triton` instead.\n"
|
||||
)
|
||||
for v in all_violations:
|
||||
print(f"❌ {v}")
|
||||
return 1
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
Reference in New Issue
Block a user