Sync from v0.13

This commit is contained in:
2026-01-19 10:38:50 +08:00
parent b2ef04d792
commit 5aef6c175a
3714 changed files with 854317 additions and 89342 deletions

14
tools/check_repo.sh Executable file
View File

@@ -0,0 +1,14 @@
#!/bin/bash
# Checks whether the repo is clean and whether tags are available (necessary to correctly produce vllm version at build time)

# A dirty working tree would make the derived build version unreliable.
git diff --quiet || { echo "Repo is dirty" >&2; exit 1; }

# `git describe` needs at least one reachable tag; shallow CI clones often lack them.
git describe --tags || { echo "No tags are present. Is this a shallow clone? git fetch --unshallow --tags" >&2; exit 1; }

View File

@@ -0,0 +1,28 @@
# Expert parallel kernels
Large-scale cluster-level expert parallelism, as described in the [DeepSeek-V3 Technical Report](http://arxiv.org/abs/2412.19437), is an efficient way to deploy sparse MoE models with many experts. However, such a deployment requires many components beyond a normal Python package, including system package support and system driver support. It is impossible to bundle all of these components into a Python package.
Here we break down the requirements into two steps:
1. Build and install the Python libraries (both [pplx-kernels](https://github.com/ppl-ai/pplx-kernels) and [DeepEP](https://github.com/deepseek-ai/DeepEP)), including necessary dependencies like NVSHMEM. This step does not require any privileged access. Any user can do this.
2. Configure NVIDIA driver to enable IBGDA. This step requires root access, and must be done on the host machine.
Step 2 is necessary for multi-node deployment.
All scripts accept a positional argument as workspace path for staging the build, defaulting to `$(pwd)/ep_kernels_workspace`.
## Usage
```bash
# for hopper
TORCH_CUDA_ARCH_LIST="9.0" bash install_python_libraries.sh
# for blackwell
TORCH_CUDA_ARCH_LIST="10.0" bash install_python_libraries.sh
```
Additional step for multi-node deployment:
```bash
sudo bash configure_system_drivers.sh # update-initramfs can take several minutes
sudo reboot # Reboot is required to load the new driver
```

View File

@@ -0,0 +1,17 @@
# Enable IBGDA for the NVIDIA driver (required for multi-node expert-parallel
# deployment; see the accompanying README). Must be run with root privileges,
# e.g. `sudo bash configure_system_drivers.sh`, followed by a reboot.
set -ex

# turn on IBGDA: append the required module options to the nvidia modprobe config
echo 'options nvidia NVreg_EnableStreamMemOPs=1 NVreg_RegistryDwords="PeerMappingOverride=1;"' | tee -a /etc/modprobe.d/nvidia.conf

# Rebuild the initramfs so the new module options are picked up at next boot.
if command -v update-initramfs &> /dev/null; then
    # for Debian/Ubuntu
    sudo update-initramfs -u
elif command -v dracut &> /dev/null; then
    # for Fedora/CentOS
    sudo dracut --force
else
    echo "No supported initramfs update tool found."
    exit 1
fi

echo "Please reboot the system to apply the changes"

View File

@@ -0,0 +1,92 @@
From 18c0599c2f07ec965132efa25961dc8179c2dda3 Mon Sep 17 00:00:00 2001
From: Yongji Wu <wuyongji317@gmail.com>
Date: Tue, 20 May 2025 13:41:12 -0700
Subject: [PATCH] fix reinit issues due to states not cleaned up
fix double free
---
src/host/init/init.cu | 10 ++++++++++
.../internal/host/nvshmemi_mem_transport.hpp | 15 +++++++++++++++
src/modules/bootstrap/uid/bootstrap_uid.cpp | 5 +++++
3 files changed, 30 insertions(+)
diff --git a/src/host/init/init.cu b/src/host/init/init.cu
index b1c5dbf..1fecb4b 100644
--- a/src/host/init/init.cu
+++ b/src/host/init/init.cu
@@ -43,6 +43,8 @@
#include "internal/host/nvshmemi_types.h"
#include "internal/host/shared_memory.h"
#include "internal/host/nvshmemi_symmetric_heap.hpp"
+// eep-dev
+#include "internal/host/nvshmemi_mem_transport.hpp"
extern __constant__ nvshmemi_device_host_state_t nvshmemi_device_state_d;
static std::map<void *, int> registered_device_states;
@@ -1293,6 +1295,14 @@ void nvshmemid_hostlib_finalize(void *device_ctx, void *transport_device_ctx) {
/* Multi-init Multi-fini*/
nvshmemi_state = NULL;
nvshmemi_device_state.nvshmemi_is_nvshmem_initialized = 0;
+
+ // eep-dev
+ nvshmemi_mem_p2p_transport::destroy_instance();
+ nvshmemi_mem_remote_transport::destroy_instance();
+ free(nvshmemi_default_session);
+ nvshmemi_default_session = nullptr;
+ nvshmemi_device_state.nvshmemi_is_nvshmem_bootstrapped = false;
+
nvshmemi_is_device_state_ready = false;
} else
nvshmemi_boot_handle.barrier(&nvshmemi_boot_handle);
diff --git a/src/include/internal/host/nvshmemi_mem_transport.hpp b/src/include/internal/host/nvshmemi_mem_transport.hpp
index 2495844..e4f408a 100644
--- a/src/include/internal/host/nvshmemi_mem_transport.hpp
+++ b/src/include/internal/host/nvshmemi_mem_transport.hpp
@@ -36,6 +36,13 @@ class nvshmemi_mem_p2p_transport final {
return p2p_objref_;
}
}
+ // eep-dev
+ static void destroy_instance(void) {
+ if (p2p_objref_ != nullptr) {
+ delete p2p_objref_;
+ p2p_objref_ = nullptr;
+ }
+ }
void print_mem_handle(int pe_id, int transport_idx, nvshmemi_symmetric_heap &obj);
@@ -87,6 +94,14 @@ class nvshmemi_mem_remote_transport final {
}
}
+ // eep-dev
+ static void destroy_instance(void) {
+ if (remote_objref_ != nullptr) {
+ delete remote_objref_;
+ remote_objref_ = nullptr;
+ }
+ }
+
int gather_mem_handles(nvshmemi_symmetric_heap &obj, uint64_t heap_offset, size_t size);
/* On-demand registration and release of memory */
int register_mem_handle(nvshmem_mem_handle_t *local_handles, int transport_idx,
diff --git a/src/modules/bootstrap/uid/bootstrap_uid.cpp b/src/modules/bootstrap/uid/bootstrap_uid.cpp
index a1fa748..788fa96 100644
--- a/src/modules/bootstrap/uid/bootstrap_uid.cpp
+++ b/src/modules/bootstrap/uid/bootstrap_uid.cpp
@@ -630,6 +630,11 @@ int nvshmemi_bootstrap_plugin_pre_init(bootstrap_handle_t* handle, const int abi
// Discover the network for bootstrap, if not done previously.
// This code needs to be stateful to be able to be called multiple times by the caller
BOOTSTRAP_CHECK(bootstrap_net_init());
+ // eep-dev
+ if (handle->pre_init_ops != nullptr) {
+ BOOTSTRAP_PTR_FREE(handle->pre_init_ops);
+ handle->pre_init_ops = nullptr;
+ }
if (handle->pre_init_ops == nullptr) {
BOOTSTRAP_CALLOC(&handle->pre_init_ops, 1);
handle->pre_init_ops->get_unique_id = bootstrap_get_unique_id;
--
2.43.0

View File

@@ -0,0 +1,86 @@
#!/bin/bash
# Builds NVSHMEM (with DeepEP + eep patches) and installs pplx-kernels.
set -ex

# Default workspace directory
WORKSPACE=$(pwd)/eep_kernels_workspace
INSTALL_NVSHMEM=true

# Parse command line arguments:
#   -w <dir>  override the workspace directory
#   -n        skip downloading/patching NVSHMEM (reuse an existing nvshmem_src)
while getopts "w:n" opt; do
    case $opt in
        w)
            WORKSPACE="$OPTARG"
            ;;
        n)
            INSTALL_NVSHMEM=false
            ;;
        \?)
            echo "Invalid option: -$OPTARG" >&2
            exit 1
            ;;
    esac
done

if [ ! -d "$WORKSPACE" ]; then
    mkdir -p $WORKSPACE
fi

# install dependencies if not installed
pip3 install cmake torch ninja

# build nvshmem
pushd $WORKSPACE
# Download and patch NVSHMEM sources unless -n was given (then reuse nvshmem_src)
if [ "$INSTALL_NVSHMEM" = true ]; then
    mkdir -p nvshmem_src
    wget https://developer.download.nvidia.com/compute/redist/nvshmem/3.2.5/source/nvshmem_src_3.2.5-1.txz
    tar -xvf nvshmem_src_3.2.5-1.txz -C nvshmem_src --strip-components=1
    pushd nvshmem_src
    wget https://github.com/deepseek-ai/DeepEP/raw/main/third-party/nvshmem.patch
    # `git apply` needs a repo; an empty one is enough.
    git init
    git apply -vvv nvshmem.patch
    # NOTE(review): this relative path assumes eep_nvshmem.patch lives two
    # levels above nvshmem_src — only true for the default workspace layout;
    # confirm when -w points elsewhere.
    git apply --reject --whitespace=fix ../../eep_nvshmem.patch
else
    pushd nvshmem_src
fi

# assume CUDA_HOME is set correctly
if [ -z "$CUDA_HOME" ]; then
    echo "CUDA_HOME is not set, please set it to your CUDA installation directory."
    exit 1
fi

# disable all features except IBGDA
export NVSHMEM_IBGDA_SUPPORT=1

export NVSHMEM_SHMEM_SUPPORT=0
export NVSHMEM_UCX_SUPPORT=0
export NVSHMEM_USE_NCCL=0
export NVSHMEM_PMIX_SUPPORT=0
export NVSHMEM_TIMEOUT_DEVICE_POLLING=0
export NVSHMEM_USE_GDRCOPY=0
export NVSHMEM_IBRC_SUPPORT=0
export NVSHMEM_BUILD_TESTS=0
export NVSHMEM_BUILD_EXAMPLES=0
export NVSHMEM_MPI_SUPPORT=0
export NVSHMEM_BUILD_HYDRA_LAUNCHER=0
export NVSHMEM_BUILD_TXZ_PACKAGE=0
export NVSHMEM_TIMEOUT_DEVICE_POLLING=0

cmake -G Ninja -S . -B $WORKSPACE/nvshmem_build/ -DCMAKE_INSTALL_PREFIX=$WORKSPACE/nvshmem_install
cmake --build $WORKSPACE/nvshmem_build/ --target install

popd

# Make the freshly installed NVSHMEM discoverable by downstream CMake builds.
export CMAKE_PREFIX_PATH=$WORKSPACE/nvshmem_install:$CMAKE_PREFIX_PATH

# build and install pplx, require pytorch installed
pushd $WORKSPACE
git clone https://github.com/ppl-ai/pplx-kernels
cd pplx-kernels
# see https://github.com/pypa/pip/issues/9955#issuecomment-838065925
# PIP_NO_BUILD_ISOLATION=0 disables build isolation
PIP_NO_BUILD_ISOLATION=0 TORCH_CUDA_ARCH_LIST=9.0a+PTX pip install . --no-deps -v

View File

@@ -0,0 +1,190 @@
#!/usr/bin/env bash
set -ex

# usage: ./install_python_libraries.sh [options]
# --workspace <dir> workspace directory (default: ./ep_kernels_workspace)
# --mode <mode> "install" (default) or "wheel"
# --pplx-ref <commit> pplx-kernels commit hash
# --deepep-ref <commit> DeepEP commit hash

CUDA_HOME=${CUDA_HOME:-/usr/local/cuda}
PPLX_COMMIT_HASH=${PPLX_COMMIT_HASH:-"12cecfd"}
DEEPEP_COMMIT_HASH=${DEEPEP_COMMIT_HASH:-"73b6ea4"}
NVSHMEM_VER=3.3.24 # Supports both CUDA 12 and 13
WORKSPACE=${WORKSPACE:-$(pwd)/ep_kernels_workspace}
MODE=${MODE:-install}
# CUDA major release taken from nvcc; selects the matching NVSHMEM archive below.
CUDA_VERSION_MAJOR=$(${CUDA_HOME}/bin/nvcc --version | egrep -o "release [0-9]+" | cut -d ' ' -f 2)

# Parse arguments (each flag overrides the env-var/default above)
while [[ $# -gt 0 ]]; do
    case $1 in
        --workspace)
            if [[ -z "$2" || "$2" =~ ^- ]]; then
                echo "Error: --workspace requires an argument." >&2
                exit 1
            fi
            WORKSPACE="$2"
            shift 2
            ;;
        --mode)
            if [[ -z "$2" || "$2" =~ ^- ]]; then
                echo "Error: --mode requires an argument." >&2
                exit 1
            fi
            MODE="$2"
            shift 2
            ;;
        --pplx-ref)
            if [[ -z "$2" || "$2" =~ ^- ]]; then
                echo "Error: --pplx-ref requires an argument." >&2
                exit 1
            fi
            PPLX_COMMIT_HASH="$2"
            shift 2
            ;;
        --deepep-ref)
            if [[ -z "$2" || "$2" =~ ^- ]]; then
                echo "Error: --deepep-ref requires an argument." >&2
                exit 1
            fi
            DEEPEP_COMMIT_HASH="$2"
            shift 2
            ;;
        *)
            echo "Error: Unknown argument '$1'" >&2
            exit 1
            ;;
    esac
done

mkdir -p "$WORKSPACE"
WHEEL_DIR="$WORKSPACE/dist"
mkdir -p "$WHEEL_DIR"
pushd "$WORKSPACE"

# install dependencies if not installed
# (uv needs --system outside a virtualenv)
if [ -z "$VIRTUAL_ENV" ]; then
    uv pip install --system cmake torch ninja
else
    uv pip install cmake torch ninja
fi

# fetch nvshmem: pick the NVIDIA redist subdirectory matching this machine's arch
ARCH=$(uname -m)
case "${ARCH,,}" in
    x86_64|amd64)
        NVSHMEM_SUBDIR="linux-x86_64"
        ;;
    aarch64|arm64)
        NVSHMEM_SUBDIR="linux-sbsa"
        ;;
    *)
        echo "Unsupported architecture: ${ARCH}" >&2
        exit 1
        ;;
esac

NVSHMEM_FILE="libnvshmem-${NVSHMEM_SUBDIR}-${NVSHMEM_VER}_cuda${CUDA_VERSION_MAJOR}-archive.tar.xz"
NVSHMEM_URL="https://developer.download.nvidia.com/compute/nvshmem/redist/libnvshmem/${NVSHMEM_SUBDIR}/${NVSHMEM_FILE}"

pushd "$WORKSPACE"
echo "Downloading NVSHMEM ${NVSHMEM_VER} for ${NVSHMEM_SUBDIR} ..."
curl -fSL "${NVSHMEM_URL}" -o "${NVSHMEM_FILE}"
tar -xf "${NVSHMEM_FILE}"
mv "${NVSHMEM_FILE%.tar.xz}" nvshmem
rm -f "${NVSHMEM_FILE}"
# Drop unneeded archive payload to keep the workspace small.
rm -rf nvshmem/lib/bin nvshmem/lib/share
popd

# Make the extracted NVSHMEM discoverable by the pplx/DeepEP builds below.
export CMAKE_PREFIX_PATH=$WORKSPACE/nvshmem/lib/cmake:$CMAKE_PREFIX_PATH
# Returns 0 (true) when $1 is a git checkout with uncommitted changes,
# 1 otherwise (clean, not a repo, or git errored).
is_git_dirty() {
    local dir=$1
    pushd "$dir" > /dev/null
    # Fix: silence git's stderr with 2>/dev/null — the original `3>/dev/null`
    # redirected the unused fd 3, so errors (e.g. "not a git repository")
    # leaked to the terminal instead of being suppressed.
    if [ -d ".git" ] && [ -n "$(git status --porcelain 2>/dev/null)" ]; then
        popd > /dev/null
        return 0
    else
        popd > /dev/null
        return 1
    fi
}
# Clone $1 into directory $2 at optional commit $4, unless a usable checkout
# already exists. $3 is a key file whose presence marks the clone as complete.
clone_repo() {
    local repo_url=$1
    local dir_name=$2
    local key_file=$3
    local commit_hash=$4
    if [ -d "$dir_name" ]; then
        if is_git_dirty "$dir_name"; then
            # Local modifications present — leave the user's checkout untouched.
            echo "$dir_name directory is dirty, skipping clone"
        elif [ ! -d "$dir_name/.git" ] || [ ! -f "$dir_name/$key_file" ]; then
            # Missing .git or the key file: a previous clone/download was
            # interrupted, so start over from scratch.
            echo "$dir_name directory exists but clone appears incomplete, cleaning up and re-cloning"
            rm -rf "$dir_name"
            git clone "$repo_url"
            if [ -n "$commit_hash" ]; then
                cd "$dir_name"
                git checkout "$commit_hash"
                cd ..
            fi
        else
            echo "$dir_name directory exists and appears complete"
        fi
    else
        git clone "$repo_url"
        if [ -n "$commit_hash" ]; then
            cd "$dir_name"
            git checkout "$commit_hash"
            cd ..
        fi
    fi
}
# Clone and build/install one project inside $WORKSPACE.
#   $1 repo URL, $2 directory name, $3 key file (clone-completeness marker),
#   $4 commit hash, $5 extra shell snippet eval'd before the build (may be "").
# Honors the global $MODE: "install" installs into the env, otherwise a wheel
# is written to $WHEEL_DIR.
do_build() {
    local repo=$1
    local name=$2
    local key=$3
    local commit=$4
    local extra_env=$5
    pushd "$WORKSPACE"
    clone_repo "$repo" "$name" "$key" "$commit"
    cd "$name"
    # DeepEP CUDA 13 patch: add the cccl include dir to DeepEP's setup.py
    # include list (header layout changed in CUDA 13).
    if [[ "$name" == "DeepEP" && "${CUDA_VERSION_MAJOR}" -ge 13 ]]; then
        sed -i "s|f'{nvshmem_dir}/include']|f'{nvshmem_dir}/include', '${CUDA_HOME}/include/cccl']|" "setup.py"
    fi
    if [ "$MODE" = "install" ]; then
        echo "Installing $name into environment"
        eval "$extra_env" uv pip install --no-build-isolation -vvv .
    else
        echo "Building $name wheel into $WHEEL_DIR"
        eval "$extra_env" uv build --wheel --no-build-isolation -vvv --out-dir "$WHEEL_DIR" .
    fi
    popd
}
# build pplx-kernels
do_build \
    "https://github.com/ppl-ai/pplx-kernels" \
    "pplx-kernels" \
    "setup.py" \
    "$PPLX_COMMIT_HASH" \
    ""

# build DeepEP (its setup.py reads NVSHMEM_DIR, hence the extra env snippet)
do_build \
    "https://github.com/deepseek-ai/DeepEP" \
    "DeepEP" \
    "setup.py" \
    "$DEEPEP_COMMIT_HASH" \
    "export NVSHMEM_DIR=$WORKSPACE/nvshmem; "

# In wheel mode, show the user what was produced.
if [ "$MODE" = "wheel" ]; then
    echo "All wheels written to $WHEEL_DIR"
    ls -l "$WHEEL_DIR"
fi

63
tools/flashinfer-build.sh Executable file
View File

@@ -0,0 +1,63 @@
#!/usr/bin/env bash
# This script is used to build FlashInfer wheels with AOT kernels
set -ex

# FlashInfer configuration — FLASHINFER_GIT_REF and CUDA_VERSION are required
# inputs; BUILD_WHEEL chooses wheel-build (true) vs direct install (false).
FLASHINFER_GIT_REPO="https://github.com/flashinfer-ai/flashinfer.git"
FLASHINFER_GIT_REF="${FLASHINFER_GIT_REF}"
CUDA_VERSION="${CUDA_VERSION}"
BUILD_WHEEL="${BUILD_WHEEL:-true}"

if [[ -z "${FLASHINFER_GIT_REF}" ]]; then
    echo "❌ FLASHINFER_GIT_REF must be specified" >&2
    exit 1
fi

if [[ -z "${CUDA_VERSION}" ]]; then
    echo "❌ CUDA_VERSION must be specified" >&2
    exit 1
fi

echo "🏗️ Building FlashInfer ${FLASHINFER_GIT_REF} for CUDA ${CUDA_VERSION}"

# Clone FlashInfer (shallow clone with submodules keeps the checkout fast)
git clone --depth 1 --recursive --shallow-submodules \
    --branch ${FLASHINFER_GIT_REF} \
    ${FLASHINFER_GIT_REPO} flashinfer

# Set CUDA arch list based on CUDA version
# Exclude CUDA arches for older versions (11.x and 12.0-12.7)
if [[ "${CUDA_VERSION}" == 11.* ]]; then
    FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9"
elif [[ "${CUDA_VERSION}" == 12.[0-7]* ]]; then
    FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a"
else
    # CUDA 12.8+ supports 10.0a and 12.0
    FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a 10.0a 12.0"
fi
echo "🏗️ Building FlashInfer AOT for arches: ${FI_TORCH_CUDA_ARCH_LIST}"

pushd flashinfer
# Make sure the wheel is built for the correct CUDA version
# (e.g. 12.8 -> "cu128" uv torch backend)
export UV_TORCH_BACKEND=cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
# Build AOT kernels
export TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}"
export FLASHINFER_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}"
python3 -m flashinfer.aot
if [[ "${BUILD_WHEEL}" == "true" ]]; then
    # Build wheel for distribution
    uv build --no-build-isolation --wheel --out-dir ../flashinfer-dist .
    echo "✅ FlashInfer wheel built successfully in flashinfer-dist/"
else
    # Install directly (for Dockerfile)
    uv pip install --system --no-build-isolation --force-reinstall .
    echo "✅ FlashInfer installed successfully"
fi
popd

# Cleanup
rm -rf flashinfer

View File

@@ -0,0 +1,180 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import json
import multiprocessing
import os
import sys
from shutil import which

try:
    # Try to get CUDA_HOME from PyTorch installation, which is the
    # most reliable source of truth for vLLM's build.
    from torch.utils.cpp_extension import CUDA_HOME
except ImportError:
    # Without torch, fall back to the environment; may be None if unset.
    print("Warning: PyTorch not found. Falling back to CUDA_HOME environment variable.")
    CUDA_HOME = os.environ.get("CUDA_HOME")
def get_python_executable():
    """Get the current Python executable, which is used to run this script."""
    interpreter_path = sys.executable
    return interpreter_path
def get_cpu_cores():
    """Get the number of CPU cores."""
    core_count = multiprocessing.cpu_count()
    return core_count
def generate_presets(output_path="CMakeUserPresets.json", force_overwrite=False):
    """Generates the CMakeUserPresets.json file.

    Detects nvcc, the Python interpreter, CPU parallelism, a compiler cache
    (sccache/ccache) and the Ninja generator, then writes a "release"
    configure/build preset pair into the project root. Falls back to
    interactive prompts when detection fails, and asks before overwriting an
    existing file unless ``force_overwrite`` is set.
    """
    print("Attempting to detect your system configuration...")

    # Detect NVCC: prefer torch's CUDA_HOME, then PATH, then ask the user.
    nvcc_path = None
    if CUDA_HOME:
        prospective_path = os.path.join(CUDA_HOME, "bin", "nvcc")
        if os.path.exists(prospective_path):
            nvcc_path = prospective_path
            print(f"Found nvcc via torch.utils.cpp_extension.CUDA_HOME: {nvcc_path}")

    if not nvcc_path:
        nvcc_path = which("nvcc")
        if nvcc_path:
            print(f"Found nvcc in PATH: {nvcc_path}")

    if not nvcc_path:
        nvcc_path_input = input(
            "Could not automatically find 'nvcc'. Please provide the full "
            "path to nvcc (e.g., /usr/local/cuda/bin/nvcc): "
        )
        nvcc_path = nvcc_path_input.strip()
    print(f"Using NVCC path: {nvcc_path}")

    # Detect Python executable
    python_executable = get_python_executable()
    if python_executable:
        print(f"Found Python via sys.executable: {python_executable}")
    else:
        python_executable_prompt = (
            "Could not automatically find Python executable. Please provide "
            "the full path to your Python executable for vLLM development "
            "(typically from your virtual environment, e.g., "
            "/home/user/venvs/vllm/bin/python): "
        )
        python_executable = input(python_executable_prompt).strip()
        if not python_executable:
            raise ValueError(
                "Could not determine Python executable. Please provide it manually."
            )
    print(f"Using Python executable: {python_executable}")

    # Get CPU cores; cap nvcc threads at 4 and size the CMake job pool so
    # total parallelism roughly matches the core count.
    cpu_cores = get_cpu_cores()
    nvcc_threads = min(4, cpu_cores)
    cmake_jobs = max(1, cpu_cores // nvcc_threads)
    print(
        f"Detected {cpu_cores} CPU cores. "
        f"Setting NVCC_THREADS={nvcc_threads} and CMake jobs={cmake_jobs}."
    )

    # Get vLLM project root (assuming this script is in vllm/tools/)
    project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
    print(f"VLLM project root detected as: {project_root}")

    # Ensure python_executable path is absolute or resolvable
    if not os.path.isabs(python_executable) and which(python_executable):
        python_executable = os.path.abspath(which(python_executable))
    elif not os.path.isabs(python_executable):
        print(
            f"Warning: Python executable '{python_executable}' is not an "
            "absolute path and not found in PATH. CMake might not find it."
        )

    cache_variables = {
        "CMAKE_CUDA_COMPILER": nvcc_path,
        "CMAKE_BUILD_TYPE": "Release",
        "VLLM_PYTHON_EXECUTABLE": python_executable,
        "CMAKE_INSTALL_PREFIX": "${sourceDir}",
        "CMAKE_CUDA_FLAGS": "",
        "NVCC_THREADS": str(nvcc_threads),
    }

    # Detect compiler cache (sccache preferred over ccache when both exist)
    if which("sccache"):
        print("Using sccache for compiler caching.")
        for launcher in ("C", "CXX", "CUDA", "HIP"):
            cache_variables[f"CMAKE_{launcher}_COMPILER_LAUNCHER"] = "sccache"
    elif which("ccache"):
        print("Using ccache for compiler caching.")
        for launcher in ("C", "CXX", "CUDA", "HIP"):
            cache_variables[f"CMAKE_{launcher}_COMPILER_LAUNCHER"] = "ccache"
    else:
        print("No compiler cache ('ccache' or 'sccache') found.")

    configure_preset = {
        "name": "release",
        "binaryDir": "${sourceDir}/cmake-build-release",
        "cacheVariables": cache_variables,
    }
    if which("ninja"):
        print("Using Ninja generator.")
        configure_preset["generator"] = "Ninja"
        # Limit concurrent compile jobs so parallel nvcc instances do not
        # oversubscribe the machine.
        cache_variables["CMAKE_JOB_POOLS"] = f"compile={cmake_jobs}"
    else:
        print("Ninja not found, using default generator. Build may be slower.")

    presets = {
        "version": 6,
        # Keep in sync with CMakeLists.txt and requirements/build.txt
        "cmakeMinimumRequired": {"major": 3, "minor": 26, "patch": 1},
        "configurePresets": [configure_preset],
        "buildPresets": [
            {
                "name": "release",
                "configurePreset": "release",
                "jobs": cmake_jobs,
            }
        ],
    }

    output_file_path = os.path.join(project_root, output_path)
    # Never clobber an existing presets file without consent.
    if os.path.exists(output_file_path):
        if force_overwrite:
            print(f"Overwriting existing file '{output_file_path}'")
        else:
            overwrite = (
                input(f"'{output_file_path}' already exists. Overwrite? (y/N): ")
                .strip()
                .lower()
            )
            if overwrite != "y":
                print("Generation cancelled.")
                return

    try:
        with open(output_file_path, "w") as f:
            json.dump(presets, f, indent=4)
        print(f"Successfully generated '{output_file_path}'")
        print("\nTo use this preset:")
        print(f"1. Ensure you are in the vLLM root directory: cd {project_root}")
        print("2. Initialize CMake: cmake --preset release")
        print("3. Build+install: cmake --build --preset release --target install")
    except OSError as e:
        print(f"Error writing file: {e}")
if __name__ == "__main__":
    # CLI entry point: only flag is --force-overwrite (skip the overwrite prompt).
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--force-overwrite",
        action="store_true",
        help="Force overwrite existing CMakeUserPresets.json without prompting",
    )
    args = parser.parse_args()
    generate_presets(force_overwrite=args.force_overwrite)

124
tools/install_deepgemm.sh Executable file
View File

@@ -0,0 +1,124 @@
#!/bin/bash
# Script to build and/or install DeepGEMM from source
# Default: build and install immediately
# Optional: build wheels to a directory for later installation (useful in multi-stage builds)
set -e

# Default values
DEEPGEMM_GIT_REPO="https://github.com/deepseek-ai/DeepGEMM.git"
DEEPGEMM_GIT_REF="594953acce41793ae00a1233eb516044d604bcb6"
WHEEL_DIR=""

# Parse command line arguments
while [[ $# -gt 0 ]]; do
    case $1 in
        --ref)
            if [[ -z "$2" || "$2" =~ ^- ]]; then
                echo "Error: --ref requires an argument." >&2
                exit 1
            fi
            DEEPGEMM_GIT_REF="$2"
            shift 2
            ;;
        --cuda-version)
            if [[ -z "$2" || "$2" =~ ^- ]]; then
                echo "Error: --cuda-version requires an argument." >&2
                exit 1
            fi
            CUDA_VERSION="$2"
            shift 2
            ;;
        --wheel-dir)
            if [[ -z "$2" || "$2" =~ ^- ]]; then
                echo "Error: --wheel-dir requires a directory path." >&2
                exit 1
            fi
            WHEEL_DIR="$2"
            shift 2
            ;;
        -h|--help)
            echo "Usage: $0 [OPTIONS]"
            echo "Options:"
            echo "  --ref REF           Git reference to checkout (default: $DEEPGEMM_GIT_REF)"
            echo "  --cuda-version VER  CUDA version (auto-detected if not provided)"
            echo "  --wheel-dir PATH    If set, build wheel into PATH but do not install"
            echo "  -h, --help          Show this help message"
            exit 0
            ;;
        *)
            echo "Unknown option: $1" >&2
            exit 1
            ;;
    esac
done

# Auto-detect CUDA version if not provided
if [ -z "$CUDA_VERSION" ]; then
    if command -v nvcc >/dev/null 2>&1; then
        CUDA_VERSION=$(nvcc --version | grep "release" | sed -n 's/.*release \([0-9]\+\.[0-9]\+\).*/\1/p')
        echo "Auto-detected CUDA version: $CUDA_VERSION"
    else
        echo "Warning: Could not auto-detect CUDA version. Please specify with --cuda-version"
        exit 1
    fi
fi

# Extract major and minor version numbers
CUDA_MAJOR="${CUDA_VERSION%%.*}"
CUDA_MINOR="${CUDA_VERSION#${CUDA_MAJOR}.}"
CUDA_MINOR="${CUDA_MINOR%%.*}"
echo "CUDA version: $CUDA_VERSION (major: $CUDA_MAJOR, minor: $CUDA_MINOR)"

# Check CUDA version requirement; exit 0 so callers can treat unsupported
# CUDA as a harmless no-op rather than a build failure.
if [ "$CUDA_MAJOR" -lt 12 ] || { [ "$CUDA_MAJOR" -eq 12 ] && [ "$CUDA_MINOR" -lt 8 ]; }; then
    echo "Skipping DeepGEMM build/installation (requires CUDA 12.8+ but got ${CUDA_VERSION})"
    exit 0
fi

echo "Preparing DeepGEMM build..."
echo "Repository: $DEEPGEMM_GIT_REPO"
echo "Reference: $DEEPGEMM_GIT_REF"

# Create a temporary directory for the build; always removed on exit.
INSTALL_DIR=$(mktemp -d)
trap 'rm -rf "$INSTALL_DIR"' EXIT

# Clone the repository
git clone --recursive --shallow-submodules "$DEEPGEMM_GIT_REPO" "$INSTALL_DIR/deepgemm"
pushd "$INSTALL_DIR/deepgemm"

# Checkout the specific reference
git checkout "$DEEPGEMM_GIT_REF"

# Clean previous build artifacts
# (Based on https://github.com/deepseek-ai/DeepGEMM/blob/main/install.sh)
rm -rf build dist *.egg-info

# Build wheel
echo "🏗️ Building DeepGEMM wheel..."
python3 setup.py bdist_wheel

# If --wheel-dir was specified, copy wheels there and exit
if [ -n "$WHEEL_DIR" ]; then
    mkdir -p "$WHEEL_DIR"
    cp dist/*.whl "$WHEEL_DIR"/
    echo "✅ Wheel built and copied to $WHEEL_DIR"
    popd
    exit 0
fi

# Default behaviour: install built wheel
# (--system is needed for uv outside a virtualenv, e.g. in Docker builds)
if command -v uv >/dev/null 2>&1; then
    echo "Installing DeepGEMM wheel using uv..."
    if [ -n "$VLLM_DOCKER_BUILD_CONTEXT" ]; then
        uv pip install --system dist/*.whl
    else
        uv pip install dist/*.whl
    fi
else
    echo "Installing DeepGEMM wheel using pip..."
    python3 -m pip install dist/*.whl
fi
popd

echo "✅ DeepGEMM installation completed successfully"

54
tools/install_gdrcopy.sh Executable file
View File

@@ -0,0 +1,54 @@
#!/usr/bin/env bash
set -euo pipefail

# Usage: install_gdrcopy.sh <GDRCOPY_OS_VERSION> <GDRCOPY_CUDA_VERSION> <uuarch>
# uuarch must be "x64" or "aarch64"
# Optional: set GDRCOPY_VERSION to override the libgdrapi package version (default: 2.5.1-1)
# Requires: curl, apt-get, root privileges

# apt-get install below needs root; fail fast otherwise.
if [[ $(id -u) -ne 0 ]]; then
    echo "Must be run as root" >&2
    exit 1
fi

if [[ $# -ne 3 ]]; then
    echo "Usage: $0 <GDRCOPY_OS_VERSION> <GDRCOPY_CUDA_VERSION> <uuarch(x64|aarch64)>" >&2
    exit 1
fi

OS_VER="$1"
CUDA_VER="$2"
UUARCH_RAW="$3"

# Normalize/validate arch. NVIDIA's download URL uses "x64"/"aarch64" while
# the Debian package name uses "amd64"/"arm64", so track both spellings.
case "${UUARCH_RAW,,}" in
    aarch64|arm64)
        URL_ARCH="aarch64"
        DEB_ARCH="arm64"
        ;;
    x64|x86_64|amd64)
        URL_ARCH="x64"
        DEB_ARCH="amd64"
        ;;
    *)
        echo "Unsupported uuarch: ${UUARCH_RAW}. Use 'x64' or 'aarch64'." >&2
        exit 1
        ;;
esac

# The URL path segment is lowercase while the .deb filename keeps the
# caller-supplied casing of OS_VER.
OS_VER_LOWER="$(tr '[:upper:]' '[:lower:]' <<<"$OS_VER")"
GDRCOPY_PKG_VER="${GDRCOPY_VERSION:-2.5.1-1}"
DEB_NAME="libgdrapi_${GDRCOPY_PKG_VER}_${DEB_ARCH}.${OS_VER}.deb"
BASE_URL="https://developer.download.nvidia.com/compute/redist/gdrcopy"
URL="${BASE_URL}/CUDA%20${CUDA_VER}/${OS_VER_LOWER}/${URL_ARCH}/${DEB_NAME}"
echo "Downloading: ${URL}"

# Download into a temp dir that is removed on any exit path.
TMPDIR="$(mktemp -d)"
trap 'rm -rf "${TMPDIR}"' EXIT
curl -fSL "${URL}" -o "${TMPDIR}/${DEB_NAME}"

export DEBIAN_FRONTEND=noninteractive
apt-get update
apt-get install -y "${TMPDIR}/${DEB_NAME}"
# Trim apt caches to keep container image layers small.
apt-get clean
rm -rf /var/lib/apt/lists/*
echo "Installed ${DEB_NAME}"

View File

@@ -0,0 +1,254 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# install_prerequisites.py
import argparse
import glob
import json
import os
import subprocess
import sys
import urllib.request

# --- Configuration ---
# Wheel cache directory, reused across runs so UCX/NIXL are only built once.
WHEELS_CACHE_HOME = os.environ.get("WHEELS_CACHE_HOME", "/tmp/wheels_cache")
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
# Source checkouts and the UCX install prefix live under /tmp.
UCX_DIR = os.path.join("/tmp", "ucx_source")
NIXL_DIR = os.path.join("/tmp", "nixl_source")
UCX_INSTALL_DIR = os.path.join("/tmp", "ucx_install")
UCX_REPO_URL = "https://github.com/openucx/ucx.git"
NIXL_REPO_URL = "https://github.com/ai-dynamo/nixl.git"

# --- Helper Functions ---
def get_latest_nixl_version():
    """Helper function to get latest release version of NIXL"""
    nixl_release_url = "https://api.github.com/repos/ai-dynamo/nixl/releases/latest"
    # Fall back to a known-good pinned release if GitHub is unreachable or
    # the response lacks a tag name.
    try:
        with urllib.request.urlopen(nixl_release_url) as response:
            payload = json.load(response)
    except Exception:
        return "0.7.0"
    return payload.get("tag_name", "0.7.0")
# Resolved once at import time; set the NIXL_VERSION env var to pin a release,
# otherwise the latest GitHub release tag is queried.
NIXL_VERSION = os.environ.get("NIXL_VERSION", get_latest_nixl_version())
def run_command(command, cwd=".", env=None):
    """Helper function to run a shell command and check for errors."""
    rendered = ' '.join(command)
    print(f"--> Running command: {rendered} in '{cwd}'", flush=True)
    subprocess.check_call(command, cwd=cwd, env=env)
def is_pip_package_installed(package_name):
    """Checks if a package is installed via pip without raising an exception."""
    show_cmd = [sys.executable, "-m", "pip", "show", package_name]
    completed = subprocess.run(
        show_cmd,
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    )
    # `pip show` exits 0 only when the package is present.
    return completed.returncode == 0
def find_nixl_wheel_in_cache(cache_dir):
    """Finds a nixl wheel file in the specified cache directory."""
    # The repaired wheel will have a 'manylinux' tag, but this glob still works.
    pattern = os.path.join(cache_dir, f"nixl*{NIXL_VERSION}*.whl")
    matches = sorted(glob.glob(pattern))
    # Highest-sorting (most recent/highest version) wheel wins; None if absent.
    return matches[-1] if matches else None
def install_system_dependencies():
    """Installs required system packages using apt-get if run as root.

    When not running as root, only prints the package list so the user can
    install them manually — it does not fail.
    """
    if os.geteuid() != 0:
        print("\n---", flush=True)
        print(
            "WARNING: Not running as root. \
Skipping system dependency installation.",
            flush=True,
        )
        print(
            "Please ensure the listed packages are installed on your system:",
            flush=True,
        )
        print(
            " patchelf build-essential git cmake ninja-build \
autotools-dev automake meson libtool libtool-bin",
            flush=True,
        )
        print("---\n", flush=True)
        return

    print("--- Running as root. Installing system dependencies... ---", flush=True)
    # Build toolchain for UCX (autotools) and NIXL (meson/ninja); patchelf is
    # needed by auditwheel to repair the built wheel.
    apt_packages = [
        "patchelf",  # <-- Add patchelf here
        "build-essential",
        "git",
        "cmake",
        "ninja-build",
        "autotools-dev",
        "automake",
        "meson",
        "libtool",
        "libtool-bin",
        "pkg-config",
    ]
    run_command(["apt-get", "update"])
    run_command(["apt-get", "install", "-y"] + apt_packages)
    print("--- System dependencies installed successfully. ---\n", flush=True)
def build_and_install_prerequisites(args):
    """Builds UCX and NIXL from source, creating a self-contained wheel.

    Order of preference: (1) skip everything if NIXL is already installed,
    (2) install a previously built wheel from the cache, (3) full build —
    UCX from source, then the NIXL wheel, then an auditwheel repair that
    bundles the UCX libraries into the wheel. ``args.force_reinstall``
    bypasses (1) and (2).
    """
    if not args.force_reinstall and is_pip_package_installed("nixl"):
        print("--> NIXL is already installed. Nothing to do.", flush=True)
        return

    cached_wheel = find_nixl_wheel_in_cache(WHEELS_CACHE_HOME)
    if not args.force_reinstall and cached_wheel:
        print(
            f"\n--> Found self-contained wheel: \
{os.path.basename(cached_wheel)}.",
            flush=True,
        )
        print("--> Installing from cache, skipping all source builds.", flush=True)
        install_command = [sys.executable, "-m", "pip", "install", cached_wheel]
        run_command(install_command)
        print("\n--- Installation from cache complete. ---", flush=True)
        return

    print(
        "\n--> No installed package or cached wheel found. \
Starting full build process...",
        flush=True,
    )
    # auditwheel performs the wheel repair in step 3.
    print("\n--> Installing auditwheel...", flush=True)
    run_command([sys.executable, "-m", "pip", "install", "auditwheel"])
    install_system_dependencies()
    ucx_install_path = os.path.abspath(UCX_INSTALL_DIR)
    print(f"--> Using wheel cache directory: {WHEELS_CACHE_HOME}", flush=True)
    os.makedirs(WHEELS_CACHE_HOME, exist_ok=True)

    # -- Step 1: Build UCX from source --
    print("\n[1/3] Configuring and building UCX from source...", flush=True)
    if not os.path.exists(UCX_DIR):
        run_command(["git", "clone", UCX_REPO_URL, UCX_DIR])
    ucx_source_path = os.path.abspath(UCX_DIR)
    run_command(["git", "checkout", "v1.19.x"], cwd=ucx_source_path)
    run_command(["./autogen.sh"], cwd=ucx_source_path)
    configure_command = [
        "./configure",
        f"--prefix={ucx_install_path}",
        "--enable-shared",
        "--disable-static",
        "--disable-doxygen-doc",
        "--enable-optimizations",
        "--enable-cma",
        "--enable-devel-headers",
        "--with-verbs",
        "--enable-mt",
        "--with-ze=no",
    ]
    run_command(configure_command, cwd=ucx_source_path)
    run_command(["make", "-j", str(os.cpu_count() or 1)], cwd=ucx_source_path)
    run_command(["make", "install"], cwd=ucx_source_path)
    print("--- UCX build and install complete ---", flush=True)

    # -- Step 2: Build NIXL wheel from source --
    print("\n[2/3] Building NIXL wheel from source...", flush=True)
    if not os.path.exists(NIXL_DIR):
        run_command(["git", "clone", NIXL_REPO_URL, NIXL_DIR])
    else:
        # Existing checkout: refresh tags so the pinned version resolves.
        run_command(["git", "fetch", "--tags"], cwd=NIXL_DIR)
    run_command(["git", "checkout", NIXL_VERSION], cwd=NIXL_DIR)
    print(f"--> Checked out NIXL version: {NIXL_VERSION}", flush=True)
    # Point the NIXL build at the just-installed UCX.
    build_env = os.environ.copy()
    build_env["PKG_CONFIG_PATH"] = os.path.join(ucx_install_path, "lib", "pkgconfig")
    ucx_lib_path = os.path.join(ucx_install_path, "lib")
    ucx_plugin_path = os.path.join(ucx_lib_path, "ucx")
    existing_ld_path = os.environ.get("LD_LIBRARY_PATH", "")
    build_env["LD_LIBRARY_PATH"] = (
        f"{ucx_lib_path}:{ucx_plugin_path}:{existing_ld_path}".strip(":")
    )
    # $ORIGIN rpath lets the bundled libraries find each other after repair.
    build_env["LDFLAGS"] = "-Wl,-rpath,$ORIGIN"
    print(f"--> Using LD_LIBRARY_PATH: {build_env['LD_LIBRARY_PATH']}", flush=True)
    temp_wheel_dir = os.path.join(ROOT_DIR, "temp_wheelhouse")
    run_command(
        [
            sys.executable,
            "-m",
            "pip",
            "wheel",
            ".",
            "--no-deps",
            f"--wheel-dir={temp_wheel_dir}",
        ],
        cwd=os.path.abspath(NIXL_DIR),
        env=build_env,
    )

    # -- Step 3: Repair the wheel by copying UCX libraries --
    print("\n[3/3] Repairing NIXL wheel to include UCX libraries...", flush=True)
    unrepaired_wheel = find_nixl_wheel_in_cache(temp_wheel_dir)
    if not unrepaired_wheel:
        raise RuntimeError("Failed to find the NIXL wheel after building it.")
    # We tell auditwheel to ignore the plugin that mesonpy already handled.
    auditwheel_command = [
        "auditwheel",
        "repair",
        "--exclude",
        "libplugin_UCX.so",  # <-- Exclude because mesonpy already includes it
        unrepaired_wheel,
        f"--wheel-dir={WHEELS_CACHE_HOME}",
    ]
    run_command(auditwheel_command, env=build_env)

    # --- CLEANUP ---
    # No more temporary files to remove, just the temp wheelhouse
    run_command(["rm", "-rf", temp_wheel_dir])
    # --- END CLEANUP ---

    newly_built_wheel = find_nixl_wheel_in_cache(WHEELS_CACHE_HOME)
    if not newly_built_wheel:
        raise RuntimeError("Failed to find the repaired NIXL wheel.")
    print(
        f"--> Successfully built self-contained wheel: \
{os.path.basename(newly_built_wheel)}. Now installing...",
        flush=True,
    )
    install_command = [
        sys.executable,
        "-m",
        "pip",
        "install",
        "--no-deps",  # w/o "no-deps", it will install cuda-torch
        newly_built_wheel,
    ]
    if args.force_reinstall:
        install_command.insert(-1, "--force-reinstall")
    run_command(install_command)
    print("--- NIXL installation complete ---", flush=True)
if __name__ == "__main__":
    # CLI entry point; the single flag forces a rebuild even when NIXL (or a
    # cached wheel) is already available.
    parser = argparse.ArgumentParser(
        description="Build and install UCX and NIXL dependencies."
    )
    parser.add_argument(
        "--force-reinstall",
        action="store_true",
        help="Force rebuild and reinstall of UCX and NIXL \
even if they are already installed.",
    )
    args = parser.parse_args()
    build_and_install_prerequisites(args)

View File

@@ -0,0 +1,111 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Ensure we perform lazy loading in vllm/__init__.py.
i.e: appears only within the `if typing.TYPE_CHECKING:` guard,
**except** for a short whitelist.
"""
import ast
import sys
from collections.abc import Iterable
from pathlib import Path
from typing import Final
# The module whose imports are being policed.
INIT_PATH: Final = Path("vllm/__init__.py")
# If you need to add items to whitelist, do it here.
# Absolute module names that may be imported eagerly (``import vllm.x``).
ALLOWED_IMPORTS: Final[frozenset[str]] = frozenset(
    {
        "vllm.env_override",
    }
)
# Module spellings, exactly as written (including leading dots), that may be
# used eagerly in ``from <module> import ...`` statements.
ALLOWED_FROM_MODULES: Final[frozenset[str]] = frozenset(
    {
        ".version",
    }
)
def _is_internal(name: str | None, *, level: int = 0) -> bool:
if level > 0:
return True
if name is None:
return False
return name.startswith("vllm.") or name == "vllm"
def _fail(violations: Iterable[tuple[int, str]]) -> None:
print("ERROR: Disallowed eager imports in vllm/__init__.py:\n", file=sys.stderr)
for lineno, msg in violations:
print(f" Line {lineno}: {msg}", file=sys.stderr)
sys.exit(1)
def main() -> None:
    """Parse vllm/__init__.py and fail if any internal import is eager."""
    source = INIT_PATH.read_text(encoding="utf-8")
    tree = ast.parse(source, filename=str(INIT_PATH))
    violations: list[tuple[int, str]] = []

    class Visitor(ast.NodeVisitor):
        """Walks the module while tracking whether the current node sits
        inside an ``if TYPE_CHECKING:`` guard."""

        def __init__(self) -> None:
            super().__init__()
            # True while visiting the body of a TYPE_CHECKING-guarded `if`.
            self._in_type_checking = False

        def visit_If(self, node: ast.If) -> None:
            # Recognize both `if typing.TYPE_CHECKING:` and a bare
            # `if TYPE_CHECKING:` as the guard.
            guard_is_type_checking = False
            test = node.test
            if isinstance(test, ast.Attribute) and isinstance(test.value, ast.Name):
                guard_is_type_checking = (
                    test.value.id == "typing" and test.attr == "TYPE_CHECKING"
                )
            elif isinstance(test, ast.Name):
                guard_is_type_checking = test.id == "TYPE_CHECKING"

            if guard_is_type_checking:
                # Imports in the guarded body are exempt; the `else` branch
                # runs at runtime, so it is visited with the flag restored.
                prev = self._in_type_checking
                self._in_type_checking = True
                for child in node.body:
                    self.visit(child)
                self._in_type_checking = prev
                for child in node.orelse:
                    self.visit(child)
            else:
                self.generic_visit(node)

        def visit_Import(self, node: ast.Import) -> None:
            # Flag eager `import vllm.x` statements not in the whitelist.
            if self._in_type_checking:
                return
            for alias in node.names:
                module_name = alias.name
                if _is_internal(module_name) and module_name not in ALLOWED_IMPORTS:
                    violations.append(
                        (
                            node.lineno,
                            f"import '{module_name}' must be inside typing.TYPE_CHECKING",  # noqa: E501
                        )
                    )

        def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
            # Flag eager `from vllm.x import ...` / relative imports.
            if self._in_type_checking:
                return
            # Reconstruct the module exactly as written, with leading dots.
            module_as_written = ("." * node.level) + (node.module or "")
            if (
                _is_internal(node.module, level=node.level)
                and module_as_written not in ALLOWED_FROM_MODULES
            ):
                violations.append(
                    (
                        node.lineno,
                        f"from '{module_as_written}' import ... must be inside typing.TYPE_CHECKING",  # noqa: E501
                    )
                )

    Visitor().visit(tree)
    if violations:
        _fail(violations)


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,108 @@
#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import sys
import regex as re
# List of files (relative to repo root) that are allowed to import pickle or
# cloudpickle
#
# STOP AND READ BEFORE YOU ADD ANYTHING ELSE TO THIS LIST:
# The pickle and cloudpickle modules are known to be unsafe when deserializing
# data from potentially untrusted parties. They have resulted in multiple CVEs
# for vLLM and numerous vulnerabilities in the Python ecosystem more broadly.
# Before adding new uses of pickle/cloudpickle, please consider safer
# alternatives like msgpack or pydantic that are already in use in vLLM. Only
# add to this list if absolutely necessary and after careful security review.
ALLOWED_FILES = {
    # pickle
    "vllm/multimodal/hasher.py",
    "vllm/transformers_utils/config.py",
    "vllm/model_executor/models/registry.py",
    "vllm/compilation/caching.py",
    "vllm/distributed/utils.py",
    "vllm/distributed/parallel_state.py",
    "vllm/distributed/device_communicators/all_reduce_utils.py",
    "vllm/distributed/device_communicators/shm_broadcast.py",
    "vllm/distributed/device_communicators/shm_object_storage.py",
    "vllm/utils/hashing.py",
    "tests/tokenizers_/test_hf.py",
    "tests/utils_/test_hashing.py",
    "benchmarks/kernels/graph_machete_bench.py",
    "benchmarks/kernels/benchmark_lora.py",
    "benchmarks/kernels/benchmark_machete.py",
    "benchmarks/fused_kernels/layernorm_rms_benchmarks.py",
    "benchmarks/cutlass_benchmarks/w8a8_benchmarks.py",
    "benchmarks/cutlass_benchmarks/sparse_benchmarks.py",
    # cloudpickle
    "vllm/v1/executor/multiproc_executor.py",
    "vllm/v1/executor/ray_executor.py",
    "vllm/entrypoints/llm.py",
    "tests/utils.py",
    # pickle and cloudpickle
    "vllm/v1/serial_utils.py",
}

# Matches "import pickle"/"import cloudpickle" (optionally aliased) and
# "from pickle import ..."/"from cloudpickle import ..." at line start.
PICKLE_RE = re.compile(
    r"^\s*(import\s+(pickle|cloudpickle)(\s|$|\sas)"
    r"|from\s+(pickle|cloudpickle)\s+import\b)"
)


def scan_file(path: str) -> int:
    """Return 1 (and print an error) if *path* imports pickle/cloudpickle."""
    with open(path, encoding="utf-8") as handle:
        for lineno, text in enumerate(handle, 1):
            if not PICKLE_RE.match(text):
                continue
            print(
                f"{path}:{lineno}: "
                "\033[91merror:\033[0m "  # red color
                "Found pickle/cloudpickle import"
            )
            return 1
    return 0


def main():
    """Scan every file named on the command line, skipping the allow list."""
    status = 0
    for name in sys.argv[1:]:
        if name not in ALLOWED_FILES:
            status |= scan_file(name)
    return status
def test_regex():
    """Self-check PICKLE_RE against known positive and negative examples."""
    test_cases = [
        # Should match
        ("import pickle", True),
        ("import cloudpickle", True),
        ("import pickle as pkl", True),
        ("import cloudpickle as cpkl", True),
        ("from pickle import *", True),
        ("from cloudpickle import dumps", True),
        ("from pickle import dumps, loads", True),
        ("from cloudpickle import (dumps, loads)", True),
        ("    import pickle", True),
        ("\timport cloudpickle", True),
        ("from pickle import loads", True),
        # Should not match
        ("import somethingelse", False),
        ("from somethingelse import pickle", False),
        ("# import pickle", False),
        ("print('import pickle')", False),
        ("import pickleas as asdf", False),
    ]
    for i, (line, should_match) in enumerate(test_cases):
        result = PICKLE_RE.match(line) is not None
        assert result == should_match, (
            f"Test case {i} failed: '{line}' (expected {should_match}, got {result})"
        )
    print("All regex tests passed.")


if __name__ == "__main__":
    if "--test-regex" in sys.argv:
        test_regex()
    else:
        sys.exit(main())

View File

@@ -0,0 +1,151 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import sys
from enum import Enum
class SPDXStatus(Enum):
    """Possible states of a file's SPDX header."""

    EMPTY = "empty"  # empty __init__.py
    COMPLETE = "complete"
    MISSING_LICENSE = "missing_license"  # Only has copyright line
    MISSING_COPYRIGHT = "missing_copyright"  # Only has license line
    MISSING_BOTH = "missing_both"  # Completely missing


FULL_SPDX_HEADER = (
    "# SPDX-License-Identifier: Apache-2.0\n"
    "# SPDX-FileCopyrightText: Copyright contributors to the vLLM project"
)
LICENSE_LINE = "# SPDX-License-Identifier: Apache-2.0"
COPYRIGHT_LINE = "# SPDX-FileCopyrightText: Copyright contributors to the vLLM project"  # noqa: E501


def check_spdx_header_status(file_path):
    """Classify *file_path* by which SPDX header lines it contains.

    The whole file is scanned (not just the first lines), so headers placed
    after other comments still count.
    """
    with open(file_path, encoding="UTF-8") as handle:
        lines = handle.readlines()

    if not lines:
        # Empty file (e.g. a bare __init__.py) needs no header.
        return SPDXStatus.EMPTY

    # A shebang, when present, is excluded from the scan.
    body = lines[1:] if lines[0].startswith("#!") else lines

    stripped = {text.strip() for text in body}
    has_license = LICENSE_LINE in stripped
    has_copyright = COPYRIGHT_LINE in stripped

    if has_license and has_copyright:
        return SPDXStatus.COMPLETE
    if has_license:
        return SPDXStatus.MISSING_COPYRIGHT
    if has_copyright:
        return SPDXStatus.MISSING_LICENSE
    return SPDXStatus.MISSING_BOTH
def add_header(file_path, status):
    """Add or supplement the SPDX header in *file_path* based on *status*.

    The file is rewritten in place: its contents are read, the handle is
    rewound and truncated, then written back with the missing line(s) added.
    """
    with open(file_path, "r+", encoding="UTF-8") as file:
        lines = file.readlines()
        # Rewind and clear so the file can be rewritten from the top.
        file.seek(0, 0)
        file.truncate()
        if status == SPDXStatus.MISSING_BOTH:
            # Completely missing, add complete header
            if lines and lines[0].startswith("#!"):
                # Preserve shebang line
                file.write(lines[0])
                file.write(FULL_SPDX_HEADER + "\n")
                file.writelines(lines[1:])
            else:
                # Add header directly
                file.write(FULL_SPDX_HEADER + "\n")
                file.writelines(lines)
        elif status == SPDXStatus.MISSING_COPYRIGHT:
            # Only has license line, need to add copyright line
            # Find the license line and add copyright line after it
            for i, line in enumerate(lines):
                if line.strip() == LICENSE_LINE:
                    # Insert copyright line after license line
                    lines.insert(
                        i + 1,
                        f"{COPYRIGHT_LINE}\n",
                    )
                    break
            file.writelines(lines)
        elif status == SPDXStatus.MISSING_LICENSE:
            # Only has copyright line, need to add license line
            # Find the copyright line and add license line before it
            for i, line in enumerate(lines):
                if line.strip() == COPYRIGHT_LINE:
                    # Insert license line before copyright line
                    lines.insert(i, f"{LICENSE_LINE}\n")
                    break
            file.writelines(lines)
def main():
    """Check every CLI argument and patch missing SPDX headers in place.

    Exits 1 when any file needed fixing (pre-commit convention), 0 otherwise.
    """
    # Buckets keep insertion order: missing-both, then missing-copyright,
    # then missing-license.
    buckets = {
        SPDXStatus.MISSING_BOTH: [],
        SPDXStatus.MISSING_COPYRIGHT: [],
        SPDXStatus.MISSING_LICENSE: [],
    }
    for file_path in sys.argv[1:]:
        status = check_spdx_header_status(file_path)
        if status in buckets:
            buckets[status].append(file_path)

    needs_fix = any(buckets.values())
    if needs_fix:
        print("The following files are missing the SPDX header:")
        for status, paths in buckets.items():
            for file_path in paths:
                print(f"  {file_path}")
                add_header(file_path, status)
    sys.exit(1 if needs_fix else 0)


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,88 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import subprocess
import sys
import regex as re
# Any line that imports triton directly ("import triton...", "from triton...").
FORBIDDEN_IMPORT_RE = re.compile(r"^(from|import)\s+triton(\s|\.|$)")

# The only accepted spellings for bringing triton into scope.
ALLOWED_LINES = {
    "from vllm.triton_utils import triton",
    "from vllm.triton_utils import tl",
    "from vllm.triton_utils import tl, triton",
}

# The wrapper module itself may import triton directly.
ALLOWED_FILES = {"vllm/triton_utils/importing.py"}


def is_allowed_file(current_file: str) -> bool:
    """True when *current_file* is exempt from the direct-import ban."""
    return current_file in ALLOWED_FILES


def is_forbidden_import(line: str) -> bool:
    """True when *line* imports triton directly and is not a whitelisted form."""
    text = line.strip()
    if text in ALLOWED_LINES:
        return False
    return FORBIDDEN_IMPORT_RE.match(text) is not None
def parse_diff(diff: str) -> list[str]:
    """Scan a unified diff (produced with --unified=0) for added lines that
    import triton directly.

    Returns "file:line: code" strings for each violation on a "+" line,
    skipping files on the allow list entirely.
    """
    violations = []
    current_file = None
    current_lineno = None
    skip_allowed_file = False
    for line in diff.splitlines():
        if line.startswith("+++ b/"):
            # New file section: remember the path and whether it is exempt.
            current_file = line[6:]
            skip_allowed_file = is_allowed_file(current_file)
        elif skip_allowed_file:
            continue
        elif line.startswith("@@"):
            # Hunk header: "+<start>" is the first added line's number.
            match = re.search(r"\+(\d+)", line)
            if match:
                current_lineno = int(match.group(1)) - 1  # next "+ line" is here
        elif line.startswith("+") and not line.startswith("++"):
            # An added source line (excluding the "+++" file header).
            current_lineno += 1
            code_line = line[1:]
            if is_forbidden_import(code_line):
                violations.append(
                    f"{current_file}:{current_lineno}: {code_line.strip()}"
                )
    return violations
def get_diff(diff_type: str) -> str:
    """Return the zero-context unified diff for staged or unstaged changes.

    Raises:
        ValueError: if *diff_type* is neither "staged" nor "unstaged".
        subprocess.CalledProcessError: if git fails.
    """
    commands = {
        "staged": ["git", "diff", "--cached", "--unified=0"],
        "unstaged": ["git", "diff", "--unified=0"],
    }
    if diff_type not in commands:
        raise ValueError(f"Unknown diff_type: {diff_type}")
    return subprocess.check_output(commands[diff_type], text=True)
def main():
    """Scan staged and unstaged diffs; report any direct triton imports."""
    found = []
    for kind in ("staged", "unstaged"):
        try:
            found.extend(parse_diff(get_diff(kind)))
        except subprocess.CalledProcessError as e:
            # A failing git invocation is reported but does not abort the scan.
            print(f"[{kind}] Git diff failed: {e}", file=sys.stderr)

    if not found:
        return 0
    print(
        "❌ Forbidden direct `import triton` detected."
        " ➤ Use `from vllm.triton_utils import triton` instead.\n"
    )
    for violation in found:
        print(f"{violation}")
    return 1


if __name__ == "__main__":
    sys.exit(main())

View File

@@ -0,0 +1,83 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import subprocess
from pathlib import Path
import regex as re
# Matches "import re" (bare, aliased, or in a comma list) and
# "from re import ..." at the start of a line.
FORBIDDEN_PATTERNS = re.compile(r"^\s*(?:import\s+re(?:$|\s|,)|from\s+re\s+import)")
# The exact approved replacement imports.
ALLOWED_PATTERNS = [
    re.compile(r"^\s*import\s+regex\s+as\s+re\s*$"),
    re.compile(r"^\s*import\s+regex\s*$"),
]
def get_staged_python_files() -> list[str]:
    """Return the staged (added/modified) Python files, or [] on git failure.

    A missing git binary (FileNotFoundError) is treated the same as a failed
    git invocation, so the hook degrades gracefully instead of crashing.
    """
    try:
        result = subprocess.run(
            ["git", "diff", "--cached", "--name-only", "--diff-filter=AM"],
            capture_output=True,
            text=True,
            check=True,
        )
    except (subprocess.CalledProcessError, FileNotFoundError):
        # Not a repo, git error, or git not installed: nothing to check.
        return []
    files = result.stdout.strip().split("\n") if result.stdout.strip() else []
    return [f for f in files if f.endswith(".py")]
def is_forbidden_import(line: str) -> bool:
    """True when *line* imports the stdlib ``re`` instead of ``regex``."""
    text = line.strip()
    if not FORBIDDEN_PATTERNS.match(text):
        return False
    # Explicitly allowed spellings win over the forbidden pattern.
    return not any(pattern.match(text) for pattern in ALLOWED_PATTERNS)
def check_file(filepath: str) -> list[tuple[int, str]]:
    """Collect (line number, stripped text) pairs of forbidden imports.

    Unreadable or non-UTF-8 files are silently skipped — this is a
    best-effort pre-commit check, not a hard gate on file encoding.
    """
    found: list[tuple[int, str]] = []
    try:
        with open(filepath, encoding="utf-8") as handle:
            for num, text in enumerate(handle, 1):
                if is_forbidden_import(text):
                    found.append((num, text.strip()))
    except (OSError, UnicodeDecodeError):
        pass
    return found
def main() -> int:
    """Report forbidden stdlib ``re`` imports in staged Python files."""
    staged = get_staged_python_files()
    if not staged:
        return 0

    total_violations = 0
    for filepath in staged:
        # setup.py is deliberately exempt; vanished paths (renames) skipped.
        if not Path(filepath).exists() or filepath == "setup.py":
            continue
        violations = check_file(filepath)
        if violations:
            print(f"\n{filepath}:")
            for line_num, line in violations:
                print(f"  Line {line_num}: {line}")
            total_violations += 1

    if total_violations == 0:
        return 0
    print(f"\n💡 Found {total_violations} violation(s).")
    print("❌ Please replace 'import re' with 'import regex as re'")
    print("   Also replace 'from re import ...' with 'from regex import ...'")  # noqa: E501
    print("✅ Allowed imports:")
    print("   - import regex as re")
    print("   - import regex")  # noqa: E501
    return 1


if __name__ == "__main__":
    raise SystemExit(main())

View File

@@ -0,0 +1,34 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Generates specialized requirements files for nightly PyTorch testing.
This script reads the main test requirements input file (`requirements/test.in`)
and splits its content into two files:
1. `requirements/nightly_torch_test.txt`: Contains dependencies
except PyTorch-related.
2. `torch_nightly_test.txt`: Contains only PyTorch-related packages.
"""
input_file = "requirements/test.in"
output_file = "requirements/nightly_torch_test.txt"
# white list of packages that are not compatible with PyTorch nightly directly
# with pip install. Please add your package to this list if it is not compatible
# or make the dependency test fails.
white_list = ["torch", "torchaudio", "torchvision", "mamba_ssm"]
with open(input_file) as f:
lines = f.readlines()
skip_next = False
for line in lines:
if skip_next:
if line.startswith((" ", "\t")) or line.strip() == "":
continue
skip_next = False
if any(k in line.lower() for k in white_list):
skip_next = True
continue

158
tools/pre_commit/mypy.py Executable file
View File

@@ -0,0 +1,158 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Run mypy on changed files.
This script is designed to be used as a pre-commit hook. It runs mypy
on files that have been changed. It groups files into different mypy calls
based on their directory to avoid import following issues.
Usage:
python tools/pre_commit/mypy.py <ci> <python_version> <changed_files...>
Args:
ci: "1" if running in CI, "0" otherwise. In CI, follow_imports is set to
"silent" for the main group of files.
python_version: Python version to use (e.g., "3.10") or "local" to use
the local Python version.
changed_files: List of changed files to check.
"""
import subprocess
import sys
import regex as re
FILES = [
    "vllm/*.py",
    "vllm/assets",
    "vllm/distributed",
    "vllm/engine",
    "vllm/entrypoints",
    "vllm/executor",
    "vllm/inputs",
    "vllm/logging_utils",
    "vllm/multimodal",
    "vllm/platforms",
    "vllm/plugins",
    "vllm/tokenizers",
    "vllm/transformers_utils",
    "vllm/triton_utils",
    "vllm/usage",
    "vllm/utils",
    "vllm/worker",
    "vllm/v1/core",
    "vllm/v1/engine",
    "vllm/v1/executor",
    "vllm/v1/metrics",
    "vllm/v1/pool",
    "vllm/v1/sample",
    "vllm/v1/worker",
]

# After fixing errors resulting from changing follow_imports
# from "skip" to "silent", move the following directories to FILES
SEPARATE_GROUPS = [
    "tests",
    # v0 related
    "vllm/attention",
    "vllm/compilation",
    "vllm/lora",
    "vllm/model_executor",
    # v1 related
    "vllm/v1/attention",
    "vllm/v1/kv_offload",
    "vllm/v1/spec_decode",
    "vllm/v1/structured_output",
]

# TODO(woosuk): Include the code from Megatron and HuggingFace.
EXCLUDE = [
    "vllm/engine/arg_utils.py",
    "vllm/model_executor/parallel_utils",
    "vllm/model_executor/models",
    "vllm/model_executor/layers/fla/ops",
    # Ignore triton kernels in ops.
    "vllm/attention/ops",
]


def _glob_to_regex(pattern: str) -> str:
    """Translate a glob-style FILES entry into a regex fragment.

    Plain directory prefixes pass through (escaped) unchanged; "*" becomes
    "[^/]*" so "vllm/*.py" covers only top-level modules, as the glob
    spelling intends. (Previously the entry was compiled verbatim as a
    regex, so "vllm/*.py" matched almost nothing and top-level files were
    silently never type-checked.)
    """
    return re.escape(pattern).replace(r"\*", "[^/]*")


def group_files(changed_files: list[str]) -> dict[str, list[str]]:
    """
    Group changed files into different mypy calls.

    Args:
        changed_files: List of changed files.

    Returns:
        A dictionary mapping file group names ("" for the main group, or a
        SEPARATE_GROUPS directory) to lists of changed files. Excluded files
        and files matching no group are dropped.
    """
    exclude_pattern = re.compile(f"^{'|'.join(EXCLUDE)}.*")
    files_pattern = re.compile(f"^({'|'.join(map(_glob_to_regex, FILES))}).*")
    file_groups = {"": []}
    file_groups.update({k: [] for k in SEPARATE_GROUPS})
    for changed_file in changed_files:
        # Skip files which should be ignored completely.
        if exclude_pattern.match(changed_file):
            continue
        # Main group first; otherwise the first matching separate group wins.
        if files_pattern.match(changed_file):
            file_groups[""].append(changed_file)
        else:
            for directory in SEPARATE_GROUPS:
                if changed_file.startswith(directory):
                    file_groups[directory].append(changed_file)
                    break
    return file_groups
def mypy(
    targets: list[str],
    python_version: str | None,
    follow_imports: str | None,
    file_group: str,
) -> int:
    """Invoke mypy on *targets*, echoing the command line first.

    Args:
        targets: Files or directories to check.
        python_version: Value for --python-version, or None for mypy's default.
        follow_imports: Value for --follow-imports, or None for mypy's default.
        file_group: Group label, used only in the echoed command line.

    Returns:
        mypy's process return code.
    """
    cmd = ["mypy"]
    for flag, value in (
        ("--python-version", python_version),
        ("--follow-imports", follow_imports),
    ):
        if value is not None:
            cmd += [flag, value]
    print(f"$ {' '.join(cmd)} {file_group}")
    return subprocess.run(cmd + targets, check=False).returncode
def main():
    """Entry point: parse argv and run mypy once per non-empty file group."""
    in_ci = sys.argv[1] == "1"
    version = sys.argv[2]
    if version == "local":
        # Resolve "local" to the interpreter's own major.minor version.
        version = f"{sys.version_info.major}.{sys.version_info.minor}"

    exit_code = 0
    for group, files in group_files(sys.argv[3:]).items():
        if not files:
            continue
        # In CI the main group follows imports silently; all other calls skip.
        imports_mode = None if in_ci and group == "" else "skip"
        exit_code |= mypy(files, version, imports_mode, group)
    return exit_code


if __name__ == "__main__":
    sys.exit(main())

15
tools/pre_commit/png-lint.sh Executable file
View File

@@ -0,0 +1,15 @@
#!/bin/bash
# Ensure that *.excalidraw.png files have the excalidraw metadata
# embedded in them. This ensures they can be loaded back into
# the tool and edited in the future.
#
# PNGs exported with "Embed Scene" enabled contain an "excalidraw+json"
# payload, which is what grep looks for below.
find . -iname '*.excalidraw.png' | while read -r file; do
  # Skip files that git ignores (e.g. local build artifacts).
  if git check-ignore -q "$file"; then
    continue
  fi
  if ! grep -q "excalidraw+json" "$file"; then
    echo "$file was not exported from excalidraw with 'Embed Scene' enabled."
    # The while loop runs in the pipeline's subshell, so this exits the
    # subshell; since the pipeline is the script's final command, its
    # status (1) becomes the script's exit status.
    exit 1
  fi
done

22
tools/pre_commit/shellcheck.sh Executable file
View File

@@ -0,0 +1,22 @@
#!/bin/bash
set -e

# Pinned shellcheck release channel.
scversion="stable"

# Reuse a previously downloaded local copy if present.
if [ -d "shellcheck-${scversion}" ]; then
  export PATH="$PATH:$(pwd)/shellcheck-${scversion}"
fi

if ! [ -x "$(command -v shellcheck)" ]; then
  if [ "$(uname -s)" != "Linux" ] || [ "$(uname -m)" != "x86_64" ]; then
    echo "Please install shellcheck: https://github.com/koalaman/shellcheck?tab=readme-ov-file#installing"
    exit 1
  fi

  # automatic local install if linux x86_64
  wget -qO- "https://github.com/koalaman/shellcheck/releases/download/${scversion?}/shellcheck-${scversion?}.linux.x86_64.tar.xz" | tar -xJv
  export PATH="$PATH:$(pwd)/shellcheck-${scversion}"
fi

# TODO - fix warnings in .buildkite/scripts/hardware_ci/run-amd-test.sh
# NOTE(review): '".git"' appears after '-name' in the expression below, but
# GNU find requires all paths to precede the expression, so this invocation
# likely errors out (leaving xargs with no input). The intent looks like
# pruning the .git directory — verify, and consider
# `find . -path ./.git -prune -o -name '*.sh' ... -print0` instead.
find . -name "*.sh" ".git" -prune -not -path "./.buildkite/scripts/hardware_ci/run-amd-test.sh" -print0 | xargs -0 -I {} sh -c 'git check-ignore -q "{}" || shellcheck -s bash "{}"'

View File

@@ -0,0 +1,81 @@
#!/bin/bash
# Update Dockerfile dependency graph when docker/Dockerfile changes.
# This script is designed to be used as a pre-commit hook.
set -euo pipefail

# Accept file paths as arguments
FILES=("$@")

# Check if docker/Dockerfile is among the provided files
if printf '%s\n' "${FILES[@]}" | grep -q "^docker/Dockerfile$"; then
  echo "docker/Dockerfile has changed, attempting to update dependency graph..."

  # Check if Docker is installed and running; without it the update is
  # skipped (exit 0) rather than failing the commit.
  if ! command -v docker &> /dev/null; then
    echo "Warning: Docker command not found. Skipping Dockerfile graph update."
    echo "Please install Docker to automatically update the graph: https://docs.docker.com/get-docker/"
    exit 0
  fi

  if ! docker info &> /dev/null; then
    echo "Warning: Docker daemon is not running. Skipping Dockerfile graph update."
    echo "Please start Docker to automatically update the graph."
    exit 0
  fi

  # Define the target file path
  TARGET_GRAPH_FILE="docs/assets/contributing/dockerfile-stages-dependency.png"

  # Ensure target directory exists
  mkdir -p "$(dirname "$TARGET_GRAPH_FILE")"

  # Store old image hash in a variable if the file exists.
  # sha256sum output includes the filename; both sides compare the same
  # path, so the comparison below is still valid.
  OLD_HASH=""
  if [ -f "$TARGET_GRAPH_FILE" ]; then
    OLD_HASH=$(sha256sum "$TARGET_GRAPH_FILE")
  fi

  # Generate Dockerfile graph (runs as the invoking user so the output
  # file is not root-owned).
  echo "Running dockerfilegraph tool..."
  docker run \
    --rm \
    --user "$(id -u):$(id -g)" \
    --workdir /workspace \
    --volume "$(pwd)":/workspace \
    ghcr.io/patrickhoefler/dockerfilegraph:alpine \
    --output png \
    --dpi 200 \
    --max-label-length 50 \
    --filename docker/Dockerfile \
    --legend

  echo "Finding generated PNG file..."
  # Check for Dockerfile.png in the root directory (most likely location)
  if [ -f "./Dockerfile.png" ]; then
    echo "Found generated file at: ./Dockerfile.png"
    mv "./Dockerfile.png" "$TARGET_GRAPH_FILE"
  else
    # Try to find it elsewhere
    DOCKERFILE_PNG=$(find . -name "Dockerfile.png" -type f | head -1)
    if [ -n "$DOCKERFILE_PNG" ]; then
      echo "Found generated file at: $DOCKERFILE_PNG"
      mv "$DOCKERFILE_PNG" "$TARGET_GRAPH_FILE"
    else
      echo "Error: Could not find the generated PNG file"
      # List PNGs modified in the last 5 minutes to aid debugging.
      find . -name "*.png" -type f -mmin -5
      exit 1
    fi
  fi

  # Check if the graph has changed; exit 1 so the user stages the new file.
  NEW_HASH=$(sha256sum "$TARGET_GRAPH_FILE")
  if [ "$NEW_HASH" != "$OLD_HASH" ]; then
    echo "Graph has changed. Please stage the updated file: $TARGET_GRAPH_FILE"
    exit 1
  else
    echo "No changes in graph detected."
  fi
fi

exit 0

View File

@@ -0,0 +1,171 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Ensures all fields in a config dataclass have default values
and that each field has a docstring.
"""
import ast
import inspect
import sys
from itertools import pairwise
import regex as re
def get_attr_docs(cls_node: ast.ClassDef) -> dict[str, str]:
"""
Get any docstrings placed after attribute assignments in a class body.
Adapted from https://davidism.com/attribute-docstrings/
https://davidism.com/mit-license/
"""
out = {}
# Consider each pair of nodes.
for a, b in pairwise(cls_node.body):
# Must be an assignment then a constant string.
if (
not isinstance(a, (ast.Assign, ast.AnnAssign))
or not isinstance(b, ast.Expr)
or not isinstance(b.value, ast.Constant)
or not isinstance(b.value.value, str)
):
continue
doc = inspect.cleandoc(b.value.value)
# An assignment can have multiple targets (a = b = v), but an
# annotated assignment only has one target.
targets = a.targets if isinstance(a, ast.Assign) else [a.target]
for target in targets:
# Must be assigning to a plain name.
if not isinstance(target, ast.Name):
continue
out[target.id] = doc
return out
class ConfigValidator(ast.NodeVisitor):
    """Visitor that validates classes decorated with @config."""

    def __init__(self): ...

    def visit_ClassDef(self, node):
        # Collect the relevant decorator names: bare @config / @dataclass,
        # plus the @dataclass(...) call form. (The @config(...) call form is
        # not collected, so it must be applied as a bare name.)
        names = set()
        for deco in node.decorator_list:
            if isinstance(deco, ast.Name) and deco.id in ("config", "dataclass"):
                names.add(deco.id)
            elif (
                isinstance(deco, ast.Call)
                and isinstance(deco.func, ast.Name)
                and deco.func.id == "dataclass"
            ):
                names.add("dataclass")

        if names == {"config", "dataclass"}:
            validate_class(node)
        elif names == {"config"}:
            fail(f"Class {node.name} with config decorator must be a dataclass.", node)
        self.generic_visit(node)
def validate_class(class_node: ast.ClassDef):
    """Check that each field of a config dataclass has a default value and an
    attribute docstring, and that Literal alternatives are not Union-ed."""
    attr_docs = get_attr_docs(class_node)
    for stmt in class_node.body:
        # A field is defined as a class variable that has a type annotation.
        if isinstance(stmt, ast.AnnAssign):
            # Skip ClassVar and InitVar
            # see https://docs.python.org/3/library/dataclasses.html#class-variables
            # and https://docs.python.org/3/library/dataclasses.html#init-only-variables
            if (
                isinstance(stmt.annotation, ast.Subscript)
                and isinstance(stmt.annotation.value, ast.Name)
                and stmt.annotation.value.id in {"ClassVar", "InitVar"}
            ):
                continue
            if isinstance(stmt.target, ast.Name):
                field_name = stmt.target.id
                # No assigned value means the dataclass field has no default.
                if stmt.value is None:
                    fail(
                        f"Field '{field_name}' in {class_node.name} must have "
                        "a default value.",
                        stmt,
                    )
                if field_name not in attr_docs:
                    fail(
                        f"Field '{field_name}' in {class_node.name} must have "
                        "a docstring.",
                        stmt,
                    )
                # Reject Union[...] annotations containing more than one
                # Literal[...]; the literals should be merged into one.
                if (
                    isinstance(stmt.annotation, ast.Subscript)
                    and isinstance(stmt.annotation.value, ast.Name)
                    and stmt.annotation.value.id == "Union"
                    and isinstance(stmt.annotation.slice, ast.Tuple)
                ):
                    args = stmt.annotation.slice.elts
                    literal_args = [
                        arg
                        for arg in args
                        if isinstance(arg, ast.Subscript)
                        and isinstance(arg.value, ast.Name)
                        and arg.value.id == "Literal"
                    ]
                    if len(literal_args) > 1:
                        fail(
                            f"Field '{field_name}' in {class_node.name} must "
                            "use a single "
                            "Literal type. Please use 'Literal[Literal1, "
                            "Literal2]' instead of 'Union[Literal1, Literal2]'"
                            ".",
                            stmt,
                        )
def validate_ast(tree: ast.stmt):
    """Walk *tree* with a fresh ConfigValidator."""
    validator = ConfigValidator()
    validator.visit(tree)
def validate_file(file_path: str):
    """Parse *file_path* and validate its config dataclasses.

    Prints progress; on a validation failure, prints the error and exits
    with status 1.
    """
    try:
        print(f"Validating {file_path} config dataclasses ", end="")
        with open(file_path, encoding="utf-8") as handle:
            tree = ast.parse(handle.read(), filename=file_path)
        validate_ast(tree)
    except ValueError as e:
        print(e)
        raise SystemExit(1) from e
    print("")
def fail(message: str, node: ast.stmt):
    """Abort validation by raising a ValueError tagged with the node's line."""
    error = f"❌ line({node.lineno}): {message}"
    raise ValueError(error)
def main():
    """Validate every candidate file passed on the command line."""
    candidates = (
        name
        for name in sys.argv[1:]
        # Only Python files under vllm/ or tests/ are in scope.
        if re.match(r"^(vllm|tests)/.*\.py$", name)
    )
    for name in candidates:
        with open(name, encoding="utf-8") as handle:
            # Cheap textual pre-filter before paying for an AST parse.
            if "@config" in handle.read():
                validate_file(name)


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,174 @@
# gputrc2graph.py
This script processes NVIDIA Nsight Systems (`nsys`) GPU trace files
(`.nsys-rep`) with -t cuda tracing enabled, and generates kernel-level
summaries and visualizations of GPU and non-GPU time. It is useful for
profiling and analyzing nsys profile output.
## Usage
### Command-line Arguments
- `--in_file`
**(required)**
List of input files and their metadata. Each entry should be in the format:
`<nsys-rep>,<engine>,<model>,<elapsed_nonprofiled_sec>`
- `nsys-rep`: Path to the `.nsys-rep` file.
- `engine`: Engine name (e.g., `vllm`).
- `model`: Model name (e.g., `llama`, `gpt-oss`, `ds`).
- `elapsed_nonprofiled_sec`: Wall-clock runtime (in seconds) without
profiling. Specify `0` to use the elapsed time from the nsys-rep file
(this may inflate non-GPU time if actual runtime without profiling is
less). Multiple entries can be provided, separated by spaces.
- `--out_dir`
Output directory for the generated CSV and HTML files.
If not specified, results are saved in the current directory.
- `--title`
Title for the HTML chart/visualization.
- `--nsys_cmd`
Path to the `nsys` command.
Default: `nsys` (assumes it is in your PATH).
Use this if `nsys` is not in your system PATH.
## Notes
- Make sure you have pandas installed.
- Make sure [nsys](https://developer.nvidia.com/nsight-systems/get-started) is installed, and specify the path to the `nsys` command with `--nsys_cmd` if it is not in your PATH.
- For more details on available engines and models, see the help string in
the script or run:
```bash
python3 gputrc2graph.py --help
```
## Example 1: analyze a single profile
To analyze the GPU cycles for say, gpt-oss model with vLLM engine:
1. Run the following command to collect nsys profile, for vllm serve config.
```bash
nsys profile -t cuda -o run1 -f true --trace-fork-before-exec=true \
--cuda-graph-trace=node --delay <DELAY> --duration <DURATION> \
vllm serve openai/gpt-oss-120b ...
```
where:
- DELAY: how many seconds to delay nsys from collecting profiles, needed so
that profiles aren't captured till vllm server has come up and load
generation starts.
- DURATION: how many seconds for nsys profile to run before generating the
profile. This should be > the duration of the run.
2. Run again, this time without collecting the profile, and get the total run
time in seconds. This value will be used by the script to calculate the
CPU(non-GPU) seconds for the analysis.
3. Say the run elapsed time is 306 seconds, from step #2. Run script to
analyze:
```bash
python3 gputrc2graph.py \
--in_file run1.nsys-rep,vllm,gpt-oss,306 \
--title "vLLM-gpt-oss profile"
```
The command will produce 2 files for analysis:
- result.html: this categorizes kernel names into different categories in a
stacked bar chart.
- result.csv: shows how the kernel names are mapped to the different
categories.
### HTML visualization with result.html
The html file shows the number of elapsed seconds due to different GPU
Substages or categories, which consist of moe_gemm (Mixture of Experts GEMM)
kernels the biggest category, at 148 seconds, followed by "attn" or attention
kernels. This lets the user prioritize the kernels to focus on for performance
optimizations.
![Example GPU Trace Visualization](images/html.png)
There's also an appended data table underneath the bar chart for copying out to other post-processing tools.
![Example GPU Trace Table](images/html_tbl.png)
### Kernel to category mapping with result.csv
Suppose the user would like to focus on improving triton kernels. It's not the
biggest consumer of cycles at 9.74 sec but perhaps it hasn't been optimized.
The next step is to use the result.csv to dive into what the kernels are which
compose the triton kernel GPU cycles. The following image shows that
triton_poi_fused__to_copy_add_addmm_cat_.. kernel to be the biggest
contributor to GPU cycles.
![Example GPU Trace csv](images/csv1.png)
## Example 2: analyze multiple profiles
Suppose the user has multiple nsys trace files, captured for different models
(say llama and gpt-oss in this case), and wishes to compare their GPU/non-GPU
time; something like the following command can be used.
```bash
python3 gputrc2graph.py \
--in_file run1.nsys-rep,vllm,llama,100 run2.nsys-rep,vllm,gpt-oss,102 \
--out_dir results \
--title "Comparison of vLLM Models"
```
The analysis process is similar to example 1 but now there will be multiple
stack bar charts that can be compared. The categories for the different
kernels will remain the same, so that it's easy to compare the GPU cycles for
the same categories.
Once a category is shown to have more cycles for one configuration than
another, the next step would be to use the csv file to see what kernels are
mapped into that category, and which kernels are taking the largest amount of
time which would cause a difference for the overall category.
## Example 3: add new classification for a new model
To create a new engine DEF with model ABC, add another JSON file in the same
directory as gputrc2graph.py, following the same format as the existing JSON
files; the script automatically picks up every JSON file in that directory as
an engine/model specification.
For example, suppose the new model has 4 kernels to be classified into "gemm"
and "attn", where the gemm kernels have names containing "H" or "I", and the
attn kernels have names containing "J" or "K". The JSON file would look like
the following:
```json
{
"DEF": {
"ABC": {
"H|I": "gemm",
"J|K": "attn",
"CUDA mem": "non-gpu-H_D_memops",
".*": "misc"
}
}
}
```
Each entry in the dictionary consists of:
- key: a regex used to classify the kernels
- value: the category to classify the kernels into.
The last 2 entries are common for all engine/models, consisting of CUDA memory
operations and a 'misc' for anything that's leftover and can't be classified.
When invoking gputrc2graph.py, specify a trace file with this new model/engine
like the following:
```bash
--in_file new.nsys-rep,DEF,ABC,<runtime>
```
If the engine_DEF.json file already exists, just add the model as a new node in
the existing engine file, after the other models.

View File

@@ -0,0 +1,344 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This generates gpu kernel analysis output from nsys rep. Will call nsys
stats -r cuda_gpu_kern_trace, get non-overlapped gpu cycles, then generate
csv and html output for analysis
"""
import argparse
import logging
import os
import regex as re
logger = logging.getLogger(__name__)
# helper for loading the engine/model kernel-classification specs from json files
def load_engine_model():
    """Merge every ``*.json`` spec next to this script into one engine->model dict."""
    import glob
    import json

    spec_dir = os.path.dirname(__file__) or "."
    merged = {}
    for spec_path in glob.glob(os.path.join(spec_dir, "*.json")):
        with open(spec_path, encoding="utf-8") as handle:
            merged.update(json.load(handle))
    return merged
class GPUTrace2Graph:
    """
    Parses output of nsys report, generates csv and bar chart output.

    Workflow: ``gen_graph`` -> ``gen_sum_file`` (runs ``nsys stats`` if
    needed, then builds a non-overlapped per-kernel summary csv) ->
    ``make_html`` (plotly stacked bar chart + pivot table + csv).
    """

    def __init__(self):
        import pandas as pd  # avoid importing till needed

        self.pd = pd
        self.pd.options.mode.copy_on_write = True
        # Instance logger (same name/config as the module logger) so the
        # class does not depend on a module-level global.
        self._log = logging.getLogger(__name__)

    # helper functions for generating trace->summary csvs
    def gen_nonoverlapped_sum_from_gputrace(self, in_file, out_file):
        """Read an nsys cuda_gpu_trace csv (`in_file`) and write a per-kernel
        summary csv (`out_file`) containing both raw summed duration and
        non-overlapped ("elapsed") GPU time per kernel name."""
        self._log.info("loading %s", in_file)
        df = self.pd.read_csv(
            in_file, usecols=["Start (ns)", "Duration (ns)", "Device", "Strm", "Name"]
        )
        df["End (ns)"] = df["Start (ns)"] + df["Duration (ns)"]
        df = self.sum_non_overlapping_intervals(df)
        # get ready to print table with elapsed times per kernel
        df["Instances"] = 1
        df_sum = df.groupby("Name", as_index=False).agg(
            {"Elapsed Time (ns)": "sum", "Duration (ns)": "sum", "Instances": "size"}
        )
        # generate csv (seconds, sorted by non-overlapped time)
        df_sum["Total Time (sec)"] = df_sum["Duration (ns)"] / 1e9
        df_sum["Elapsed Time (sec)"] = df_sum["Elapsed Time (ns)"] / 1e9
        df_sum = df_sum.sort_values(by="Elapsed Time (sec)", ascending=False)
        df_sum[["Elapsed Time (sec)", "Total Time (sec)", "Instances", "Name"]].to_csv(
            out_file, index=False
        )

    def sum_non_overlapping_intervals(self, df):
        """
        Return `df` sorted by start time with an added "Elapsed Time (ns)"
        column: the portion of each kernel's duration not already covered by
        earlier-starting kernels.  Summing that column therefore never double
        counts concurrent kernels.
        """
        if df.empty:
            # nothing to walk; also avoids indexing into an empty array below
            return df
        self._log.info("sorting %s trace records by start time", str(df.shape))
        # Sort by start time and reset index
        df = df.sort_values(by="Start (ns)").reset_index(drop=True)
        # Initialize elapsed time as duration
        df["Elapsed Time (ns)"] = df["Duration (ns)"]
        # Get numpy arrays for faster operations
        starts = df["Start (ns)"].values
        ends = df["End (ns)"].values
        # hoist the column lookup out of the loop
        elapsed_col = df.columns.get_loc("Elapsed Time (ns)")
        # Keep track of current interval end
        current_end = ends[0]
        # progress-print granularity; at least 1 so the modulo below never
        # divides by zero on traces with fewer than 100 records
        display_units = max(1, len(df) // 100)
        # Update current_end for overlapping intervals
        for i in range(1, len(df)):
            if i % display_units == 0:
                print(f"processing trace: {int(i / len(df) * 100)} %", end="\r")
            if starts[i] <= current_end:
                if ends[i] > current_end:
                    # Partial overlap: only the tail past current_end is new
                    df.iloc[i, elapsed_col] = ends[i] - current_end
                    current_end = ends[i]
                else:
                    # Complete overlap: contributes no new time
                    df.iloc[i, elapsed_col] = 0
            else:
                # No overlap
                current_end = ends[i]
        return df

    # functions for generating html files
    def make_html(self, df, output_dir, title):
        """Write <output_dir>/result.html (stacked bar chart + pivot table)
        and <output_dir>/result.csv (kernel -> category mapping) from `df`."""
        import plotly.express as px  # deferred: only needed for html output

        if df.empty:
            return
        output_name = output_dir + "/result"
        if not title:
            title = "Model_Engine"
        x = "Model_Engine"
        y = "Elapsed Time (sec)"
        color = "Category"
        # generate kernel mapping table
        # Sort Model_Engine categories by last field after underscore
        df["Model_Engine"] = self.pd.Categorical(
            df["Model_Engine"],
            sorted(df["Model_Engine"].unique(), key=lambda label: label.split("_")[-1]),
        )
        df[["Model_Engine", color, "Instances", "Name", y]].sort_values(
            by=color
        ).to_csv(f"{output_name}.csv", index=False)
        graph = px.histogram(
            df.round(2),
            x=x,
            y=y,
            title=(f"{y} for {title}"),
            color=color,
            text_auto=True,
        )
        # wrap x axis labels
        graph.update_xaxes(automargin=True)
        graph.write_html(f"{output_name}.html")
        # Append a data table with columns per Model_Engine to result.html.
        # to_html() with no buffer returns a string, so no temp file is needed.
        pivot_df = df.pivot_table(
            values="Elapsed Time (sec)",
            index="Category",
            columns="Model_Engine",
            aggfunc="sum",
            observed=False,
        ).round(2)
        # Add sum row at bottom
        pivot_df.loc["total_elapsed_sec"] = pivot_df.sum()
        with open(f"{output_name}.html", "a", encoding="utf-8") as outfile:
            outfile.write(pivot_df.fillna("").to_html())
        print(
            f"Finished generating: \n"
            f"  {output_name}.html for stack bar chart \n"
            f"  {output_name}.csv for Kernel-Category mapping"
        )

    def anno_gpu_kernname(self, df, mapping):
        """Add a "Category" column by matching kernel names against `mapping`,
        a {regex: category} dict checked in insertion order.  The shipped json
        specs end with a ".*" catch-all, so every kernel gets a category."""

        def categorize(name):
            for pattern, category in mapping.items():
                if re.search(pattern, name):
                    return category
            # mapping without a catch-all: leave the kernel uncategorized
            return None

        df["Category"] = df["Name"].apply(categorize)

    def make_nongpu_row(self, df, nongpu_sec):
        """Return a one-row DataFrame representing CPU (non-GPU) time.
        Clones the last row of `df` to inherit its Model_Engine value, then
        overwrites the identifying fields."""
        nongpu_row = self.pd.DataFrame([df.iloc[-1]])
        nongpu_row["Category"] = nongpu_row["Name"] = "CPU(non-GPU)"
        nongpu_row["Instances"] = 1
        nongpu_row["Elapsed Time (sec)"] = nongpu_sec
        return nongpu_row

    def is_valid_file(self, base_file):
        """asserts if base_file is non-existent or is empty"""
        assert os.path.isfile(base_file) and os.path.getsize(base_file) > 0, (
            f"{base_file} doesn't exist or is empty"
        )

    def should_gen_file(self, new_file, base_file):
        """figure out if new_file should be (re)generated from base_file"""
        self.is_valid_file(base_file)
        # reuse new_file only when it is newer than its non-empty source
        if (
            os.path.exists(new_file)
            and (os.path.getmtime(new_file) > os.path.getmtime(base_file))
            and (os.path.getsize(base_file) > 0)
        ):
            self._log.info("reusing %s", new_file)
            return False
        self._log.info("generating %s", new_file)
        return True

    def gen_sum_file(self, file, nsys_cmd):
        """
        generates sum file from nsys trace with times per kernel and
        returns the name of the sum file
        """
        import subprocess  # deferred: only needed when (re)generating stats

        file_dir = os.path.dirname(file) or "."
        file_name = os.path.basename(file)
        # Walk through trace and get the total non-overlapped time
        nsys_stats_file = f"{file_dir}/{file_name}_cuda_gpu_trace.csv"
        sum_file = f"{file_dir}/{file_name}_cuda_gpu_kernel_tracesum.csv"
        if self.should_gen_file(nsys_stats_file, file):
            cmd = [
                nsys_cmd,
                "stats",
                "-r",
                "cuda_gpu_trace",
                file,
                "-o",
                f"{file_dir}/{file_name}",
            ]
            cmd_str = " ".join(cmd)
            self._log.info("+ %s", cmd_str)
            # estimate time based on calibrated 240M/min
            file_size_mb = os.path.getsize(file) / 1e6
            self._log.info(
                "nsys stats for %.2f MB file expected to take %.2f min",
                file_size_mb,
                file_size_mb / 240,
            )
            try:
                subprocess.run(cmd, check=True)
            except Exception:
                self._log.error(
                    "%s failed; Use --nsys_cmd to specify nsys path", cmd_str
                )
                raise SystemExit(1)
        self._log.info("generating non-overlapped sum %s", sum_file)
        self.gen_nonoverlapped_sum_from_gputrace(nsys_stats_file, sum_file)
        self.is_valid_file(sum_file)
        self._log.info("Finished generating %s", sum_file)
        return sum_file

    def gen_graph(self, in_file, out_dir, title, nsys_cmd, engine_model):
        """Generate graph and csv files from `in_file` into `out_dir`.

        in_file: list of (nsys-rep path, engine, model, elapsed_sec) tuples;
        engine_model: {engine: {model: {regex: category}}} spec dict.
        """
        # Initialize an empty DataFrame to store combined data
        combined_df = self.pd.DataFrame()
        for idx, (file, engine, model, total_sec) in enumerate(in_file):
            sum_file = self.gen_sum_file(file, nsys_cmd)
            # read kernel summary file
            df = self.pd.read_csv(sum_file)
            # annotate kernels with their categories
            assert engine_model.get(engine), f"engine {engine} unknown"
            assert engine_model[engine].get(model), f"model {model} unknown"
            # remove nsys-rep from file_name for shorter x-label
            file_name = os.path.basename(file).replace(".nsys-rep", "")
            df["Model_Engine"] = f"{model}_{engine}_{file_name}_{idx}"
            self.anno_gpu_kernname(df, engine_model[engine][model])
            # patch in non-gpu time (wall-clock minus summed GPU time)
            gpu_sec = round(df["Elapsed Time (sec)"].sum(), 1)
            total_sec = round(float(total_sec), 1)
            if total_sec < gpu_sec:
                self._log.warning(
                    "Elapsed sec %.2f < GPU sec %.2f resetting Elapsed sec ",
                    total_sec,
                    gpu_sec,
                )
                total_sec = gpu_sec
            nongpu_row = self.make_nongpu_row(df, total_sec - gpu_sec)
            df = self.pd.concat([df, nongpu_row], ignore_index=True)
            combined_df = self.pd.concat([combined_df, df], ignore_index=True)
        if out_dir is None:
            out_dir = "."
        else:
            os.makedirs(out_dir, exist_ok=True)
        # generate html file
        self.make_html(combined_df, out_dir, title)
def parse_tuple(s):
    """argparse type helper: turn 'a,b,c,d' into the tuple ('a', 'b', 'c', 'd')."""
    fields = s.split(",")
    return tuple(fields)
def main():
    """CLI entry point: parse arguments, then generate the chart/csv output."""
    logging.basicConfig(
        format=("%(asctime)s - %(levelname)s - %(message)s"), level=logging.INFO
    )
    parser = argparse.ArgumentParser(
        description=(
            "Process nsys rep and generate kernel non-overlapped cycles. \n"
            "Example:\n"
            "gputrc2graph.py --in_file d1.nsys-rep,vllm,llama,100 \n"
            "d2.nsys-rep,vllm,gpt-oss,102 "
            '--out_dir results/ --title "Model=gpt-oss vLLM chart"'
        ),
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    # load supported engine_model specs from the json files next to this script
    engine_model_supported = load_engine_model()
    # Get a string representation of supported engine/model combinations
    engine_model_supported_str = ", ".join(
        f"{engine}:[{', '.join(models.keys())}]"
        for engine, models in engine_model_supported.items()
    )
    parser.add_argument(
        "--in_file",
        type=parse_tuple,
        nargs="+",
        help=(
            "list of (nsys-rep, engine, model, elapsed_nonprofiled_sec) "
            "separated by space. Elapsed_nonprofiled_sec is runtime without "
            "profiling used to calculate non-gpu time. Specify 0 to use "
            "elapsed time from nsys-rep but that might inflate non-gpu time. "
            f"Available engine:[model] are: {engine_model_supported_str} "
            # fixed: the example previously advertised a non-existent
            # "--infile" flag; the actual option is "--in_file"
            f"Example: --in_file d1.nsys-rep,vllm,llama,100 "
            "d2.nsys-rep,vllm,gpt-oss,102"
        ),
        required=True,
    )
    parser.add_argument("--out_dir", help=("output dir for result.csv/html"))
    parser.add_argument("--title", help=("title for html chart"))
    parser.add_argument(
        "--nsys_cmd",
        help=("nsys cmd, e.g. /usr/bin/nsys, Default: nsys"),
        default="nsys",
    )
    args = parser.parse_args()
    gputrace = GPUTrace2Graph()
    gputrace.gen_graph(
        args.in_file, args.out_dir, args.title, args.nsys_cmd, engine_model_supported
    )
# Entry point when invoked as a script (e.g. `python gputrc2graph.py ...`).
if __name__ == "__main__":
    main()

Binary file not shown.

After

Width:  |  Height:  |  Size: 145 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 70 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 36 KiB

View File

@@ -0,0 +1,63 @@
{
"vllm": {
"llama": {
"fused_moe_kernel|GroupProblemShape|group_gemm_starts|bmm_|GemmUniversal": "moe_gemm",
"gemm|nvjet": "gemm",
"moe|sigmoid": "moe",
"CatArrayBatched|prepare_inputs": "prepare_next",
"ncclDevKernel|cross_device_reduce": "nccl_and_custom_ar",
"_norm_|Norm": "norm",
"act_and_mul_": "activation",
"Rotary": "rope",
"SoftMax": "softmax",
"flash|fmha": "attn",
"elementwise": "elementwise",
"fp8_quant|cvt_": "quantize",
"reduce_kernel": "reduce",
"triton": "triton_kernel",
"CUDA mem": "non-gpu-H_D_memops",
".*": "misc"
},
"ds": {
"block_fp8|gemm_fp8_blockwise": "block_fp8_gemm",
"fused_moe_kernel|_group_gemm|GroupProblemShape|GemmUniversal|bmm_": "moe_gemm",
"gemm|matmul|nvjet": "gemm",
"moe|sigmoid|expert": "moe",
"CatArrayBatched": "prepare_next",
"ncclDevKernel|cross_device_reduce": "nccl_and_custom_ar",
"Norm|_norm_": "norm",
"sbtopk": "topk",
"act_and_mul_": "activation",
"compute_position_kernel": "rope",
"elementwise": "elementwise",
"fp8_quant|quant_fp8|cvt_": "quantize",
"reduce": "reduce",
"SoftMax": "softmax",
"_fwd_|FlashAttn|_mla_|_attn_|fmha": "attn",
"triton": "triton_kernel",
"topk": "topk",
"CUDA mem": "non-gpu-H_D_memops",
".*": "misc"
},
"gpt-oss": {
"block_fp8|gemm_fp8_blockwise": "block_fp8_gemm",
"fused_moe_kernel|_group_gemm|GroupProblemShape|GemmUniversal|bmm_|matmul_ogs_|_topk_forward|_combined_routing|_sum_bitmatrix_rows|_compute_writeback_idx": "moe_gemm",
"gemm|matmul|nvjet": "gemm",
"moe|sigmoid|expert|splitKreduce": "moe",
"CatArrayBatched": "prepare_next",
"ncclDevKernel|cross_device_reduce": "nccl_and_custom_ar",
"Norm|_norm_": "norm",
"topk": "topk",
"act_and_mul_": "activation",
"compute_position_kernel": "rope",
"elementwise": "elementwise",
"fp8_quant|quant_fp8|cvt_|quantize": "quantize",
"reduce": "reduce",
"SoftMax": "softmax",
"_fwd_|FlashAttn|_mla_|_attn_|_flash_|flash::prepare_varlen|fmha": "attn",
"triton": "triton_kernel",
"CUDA mem": "non-gpu-H_D_memops",
".*": "misc"
}
}
}

View File

@@ -0,0 +1,87 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import json
from vllm.profiler.layerwise_profile import ModelStatsEntry, SummaryStatsEntry
from vllm.profiler.utils import TablePrinter, indent_string
def flatten_entries(entry_cls, profile_dict: dict):
    """Flatten a forest of profile nodes into (entry, depth) pairs, depth-first.

    Each node's "entry" dict is materialized via `entry_cls(**entry)`;
    children are visited with depth incremented by one.
    """
    flattened = []

    def walk(node, depth=0):
        flattened.append((entry_cls(**node["entry"]), depth))
        for child in node["children"]:
            walk(child, depth + 1)

    for root in profile_dict:
        walk(root)
    return flattened
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--json-trace",
        type=str,
        required=True,
        help="json trace file output by examples/offline_inference/profiling.py",
    )
    parser.add_argument(
        "--phase",
        type=str,
        required=True,
        # fixed: the concatenated fragments previously rendered as
        # "...This is eitherprefill or decode_n..." (missing space)
        help="The phase to print the table for. This is either "
        "prefill or decode_n, where n is the decode step "
        "number",
    )
    parser.add_argument(
        "--table",
        type=str,
        choices=["summary", "model"],
        default="summary",
        help="Which table to print, the summary table or the layerwise model table",
    )
    args = parser.parse_args()

    with open(args.json_trace) as f:
        profile_data = json.load(f)

    # fail early, listing the valid phases, if the requested one is absent
    # (fixed: message previously rendered "...one among['prefill',...]")
    assert args.phase in profile_data, (
        f"Cannot find phase {args.phase} in profile data. Choose one among "
        f"{[x for x in profile_data if 'prefill' in x or 'decode' in x]}"
    )  # noqa

    if args.table == "summary":
        entries_and_depths = flatten_entries(
            SummaryStatsEntry, profile_data[args.phase]["summary_stats"]
        )
        column_widths = dict(name=80, cuda_time_us=12, pct_cuda_time=12, invocations=15)
    elif args.table == "model":
        entries_and_depths = flatten_entries(
            ModelStatsEntry, profile_data[args.phase]["model_stats"]
        )
        column_widths = dict(
            name=60, cpu_time_us=12, cuda_time_us=12, pct_cuda_time=12, trace=60
        )

    # indent entry names based on the depth
    entries = []
    for entry, depth in entries_and_depths:
        entry.name = indent_string(
            entry.name,
            indent=depth,
            indent_style=lambda indent: "|" + "-" * indent + " ",
        )
        entries.append(entry)

    TablePrinter(type(entries[0]), column_widths).print_table(entries)

View File

@@ -0,0 +1,631 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import copy
import json
import math
import os
from pathlib import Path
from typing import Any
import matplotlib.pyplot as plt
import pandas as pd
## JSON parsing utils ####
def largest_dist_from_leaf(node: dict, depth: int = 0):
    """Return the height of `node`: the depth of its deepest leaf (0 for a leaf)."""
    children = node["children"]
    if not children:
        return depth
    return max(largest_dist_from_leaf(child, depth + 1) for child in children)
def get_entries_at_depth(
    depth: int,
    entries_and_traces: list[tuple[Any, Any]],
    node: dict,
    curr_depth: int = 0,
    trace=(),
):
    """Collect (entry, trace) pairs for nodes sitting `abs(depth)` levels above
    the leaves (-1 = kernel level, -2 = module level), appending them to
    `entries_and_traces`.  `trace` is the tuple of ancestor names, innermost
    first.  Mutates `entries_and_traces` in place; returns None.
    """
    # assert that the query is at kernel or module level
    assert depth == -1 or depth == -2
    if curr_depth == 0 and largest_dist_from_leaf(node) <= (abs(depth) - 1):
        # The tree is not tall enough!
        entries_and_traces.append((node["entry"], trace))
        return
    if largest_dist_from_leaf(node) == (abs(depth) - 1):
        entries_and_traces.append((node["entry"], trace))
    # Note: no early return above — recursion continues below a collected
    # node, but strictly shorter subtrees cannot match the height test again.
    trace = (node["entry"]["name"],) + trace
    for child in node["children"]:
        get_entries_at_depth(
            depth, entries_and_traces, child, curr_depth=curr_depth + 1, trace=trace
        )
def fold_nodes(root: dict, nodes_to_fold: list[str]):
    """Make every node named in `nodes_to_fold` a leaf by pruning its
    children (iterative depth-first walk).  Mutates and returns `root`."""
    pending = [root]
    while pending:
        current = pending.pop()
        if current["entry"]["name"] in nodes_to_fold:
            current["children"] = []
        else:
            pending.extend(current["children"])
    return root
## Operation name cleanup utils ####
def trim_string_back(string: str, width: int) -> str:
    """Truncate `string` so the result fits in `width` chars, appending '...'
    when at least 4 chars of the original remain after the cut."""
    if len(string) <= width:
        return string
    cut = len(string) - width + 3
    string = string[:-cut]
    if len(string) > 3:
        string = string + "..."
    return string
def shorten_plot_legend_strings(legend, max_char_len: int):
    """Abbreviate and truncate every matplotlib legend label in place."""
    for label in legend.get_texts():
        shortened = trim_string_back(
            abbreviate_known_names(label.get_text()), max_char_len
        )
        label.set_text(shortened)
def abbreviate_known_names(name: str) -> str:
    """Shorten well-known long vLLM layer/dtype substrings in `name`.

    Order matters: "bfloat16" must be replaced before "float16".
    """
    replacements = (
        ("MergedColumnParallelLinear", "MCPLinear"),
        ("QKVParallelLinear", "QKVPLinear"),
        ("RowParallelLinear", "RPLinear"),
        ("weight=", "w="),
        ("bfloat16", "bf16"),
        ("float16", "f16"),
    )
    for long_form, short_form in replacements:
        name = name.replace(long_form, short_form)
    return name
def attempt_to_make_names_unique(entries_and_traces):
    """Disambiguate duplicate entry names in place by appending the shortest
    ancestor-trace prefix that differs between the duplicates, joined with
    " <- ".  Duplicates whose traces are identical are left untouched (the
    later pivot_table aggregation merges them)."""
    names, non_unique_names = (set(), set())

    def all_the_same(items) -> bool:
        return all(i == items[0] for i in items)

    # first pass: find names that occur more than once
    for entry, _ in entries_and_traces:
        if entry["name"] in names:
            non_unique_names.add(entry["name"])
        else:
            names.add(entry["name"])

    for name in non_unique_names:
        entries_and_traces_with_name = [
            (entry, trace)
            for entry, trace in entries_and_traces
            if entry["name"] == name
        ]

        # zip the traces element-wise to find the first ancestor level at
        # which the duplicates diverge
        zipped_traces = list(zip(*[trace for _, trace in entries_and_traces_with_name]))
        first_trace_difference = next(
            (
                i
                for i, trace_eles in enumerate(zipped_traces)
                if not all_the_same(trace_eles)
            ),
            None,
        )

        if first_trace_difference is None:
            # can't create a unique name, leave the names as they
            # are they will get aggregated by the pivot_table call
            continue

        for entry, trace in entries_and_traces_with_name:
            entry["name"] = " <- ".join(
                (entry["name"],) + trace[: first_trace_difference + 1]
            )
## Operation grouping utils ####
"""
Group operations in the given dataframe by some high-level ops like,
- gemms
- attention
- rms_norm
etc.
"""
def group_trace_by_operations(trace_df: "pd.DataFrame") -> "pd.DataFrame":
    """Collapse per-kernel columns of `trace_df` into high-level operation
    groups (attention, gemms, lora, nccl, reduce, ...).

    Columns are matched against the predicates below in priority order; each
    column is consumed by the first category that matches it.  Matched columns
    are summed (row-wise) into one aggregate column per category and then
    dropped; unmatched columns are kept as-is.  Mutates and returns `trace_df`.
    """

    def is_rms_norm(op_name: str) -> bool:
        return "rms_norm_kernel" in op_name

    def is_attention_block(op_name: str) -> bool:
        return "flash_fwd" in op_name or "reshape_and_cache_flash_kernel" in op_name

    def is_quant(op_name: str) -> bool:
        return "scaled_fp8_quant" in op_name or "scaled_int8_quant" in op_name

    # LoRA ops
    def is_sgmv_shrink(op_name: str) -> bool:
        return "sgmv_shrink" in op_name

    def is_sgmv_expand(op_name: str) -> bool:
        return "sgmv_expand" in op_name

    def is_bgmv_shrink(op_name: str) -> bool:
        return "bgmv_shrink" in op_name

    def is_bgmv_expand(op_name: str) -> bool:
        return "bgmv_expand" in op_name

    def is_cutlass_gemm_op(op_name: str) -> bool:
        return (
            "void cutlass::Kernel" in op_name
            or "void cutlass::device_kernel" in op_name
        )

    def is_gemm_op(op_name: str) -> bool:
        # quant kernels can also contain gemm substrings; exclude them
        if is_quant(op_name):
            return False
        return (
            is_cutlass_gemm_op(op_name)
            or "xmma_gemm" in op_name
            or "gemv2T_kernel" in op_name
            or "splitKreduce" in op_name
            or "s16816gemm" in op_name
        )

    def is_elementwise_op(op_name: str) -> bool:
        return "elementwise_kernel" in op_name

    def is_mem_op(op_name: str) -> bool:
        return "memcpy" in op_name.lower() or "memset" in op_name.lower()

    def is_vocab_embedding_op(op_name: str) -> bool:
        return "vocabparallelembed" in op_name.lower()

    # nccl ops
    def is_nccl_op(op_name: str) -> bool:
        return "nccl" in op_name.lower()

    def is_nccl_all_reduce(op_name: str) -> bool:
        return is_nccl_op(op_name) and (
            "all_reduce" in op_name.lower() or "allreduce" in op_name.lower()
        )

    def is_nccl_gather(op_name: str) -> bool:
        return is_nccl_op(op_name) and "gather" in op_name.lower()

    def is_nccl_broadcast(op_name: str) -> bool:
        return is_nccl_op(op_name) and "broadcast" in op_name.lower()

    # Reduce ops types
    def is_cross_device_reduce_1stage(op_name: str) -> bool:
        return "cross_device_reduce_1stage" in op_name

    def is_cross_device_reduce_2stage(op_name: str) -> bool:
        return "cross_device_reduce_2stage" in op_name

    def is_custom_ar_all_reduce(op_name: str) -> bool:
        return "_C_custom_ar::all_reduce" in op_name

    def is_reduce_kernel(op_name: str) -> bool:
        return "reduce_kernel" in op_name

    # (aggregate column name, predicate) in matching priority order — e.g.
    # the specific nccl categories must run before the generic is_nccl_op,
    # and quant before gemm.
    categories = [
        ("attention", is_attention_block),
        ("quant_ops", is_quant),
        ("sgmv_shrink_ops", is_sgmv_shrink),
        ("sgmv_expand_ops", is_sgmv_expand),
        ("bgmv_shrink_ops", is_bgmv_shrink),
        ("bgmv_expand_ops", is_bgmv_expand),
        ("cutlass_gemm_ops", is_cutlass_gemm_op),
        ("gemm_ops", is_gemm_op),
        ("rms_norm_ops", is_rms_norm),
        ("vocab_embed_ops", is_vocab_embedding_op),
        ("mem_ops", is_mem_op),
        ("elementwise_ops", is_elementwise_op),
        ("nccl_all_reduce_ops", is_nccl_all_reduce),
        ("nccl_gather_ops", is_nccl_gather),
        ("nccl_broadcast_ops", is_nccl_broadcast),
        ("nccl_other_ops", is_nccl_op),
        ("cross_device_reduce_1stage_ops", is_cross_device_reduce_1stage),
        ("cross_device_reduce_2stage_ops", is_cross_device_reduce_2stage),
        ("custom_ar_all_reduce_ops", is_custom_ar_all_reduce),
        ("reduce_kernel_ops", is_reduce_kernel),
    ]

    # `remaining` only ever holds original columns, so the aggregate columns
    # added below can never be re-matched by a later predicate.
    remaining = list(trace_df)
    consumed: list = []
    for group_name, predicate in categories:
        matched = [col for col in remaining if predicate(col)]
        if matched:
            remaining = [col for col in remaining if col not in matched]
            trace_df[group_name] = trace_df[matched].agg("sum", axis=1)
            consumed.extend(matched)

    trace_df.drop(consumed, axis=1, inplace=True)
    return trace_df
## Data plotting utils ####
def plot_trace_df(
    traces_df: "pd.DataFrame",
    plot_metric: str,
    plot_title: str,
    output: Path | None = None,
):
    """Render one stacked-bar chart of `plot_metric` per phase and save it to
    `output` via matplotlib.  Rows of `traces_df` must carry "phase",
    "phase_desc", "name" and the metric column.
    NOTE(review): `output` is typed Optional but savefig(output) is always
    called — callers appear to always pass a path; confirm before relying on
    the default.
    """

    def get_phase_description(traces_df: "pd.DataFrame", phase: str) -> str:
        # every row of a phase carries the same description; return it
        phase_df = traces_df.query(f'phase == "{phase}"')
        descs = phase_df["phase_desc"].to_list()
        assert all([desc == descs[0] for desc in descs])
        return descs[0]

    phases = traces_df["phase"].unique()
    phase_descs = [get_phase_description(traces_df, p) for p in phases]

    # one row per phase, one column per op name, cells = summed metric;
    # then collapse the op columns into high-level groups
    traces_df = traces_df.pivot_table(
        index="phase", columns="name", values=plot_metric, aggfunc="sum"
    )

    traces_df = group_trace_by_operations(traces_df)

    # Make the figure
    fig_size_x = max(5, len(phases))
    fig, ax = plt.subplots(1, figsize=(fig_size_x, 8), sharex=True)

    # Draw the stacked bars
    ops = list(traces_df)
    bottom = [0] * len(phases)
    for op in ops:
        values = [traces_df[op][phase] for phase in phases]
        # NaN means the op never ran in that phase; plot it as zero
        values = list(map(lambda x: 0.0 if math.isnan(x) else x, values))
        ax.bar(phase_descs, values, label=op, bottom=bottom)
        bottom = [bottom[j] + values[j] for j in range(len(phases))]

    # Write the values as text on the bars
    for bar in ax.patches:
        if bar.get_height() != 0:
            ax.text(
                bar.get_x() + bar.get_width() / 2,
                bar.get_height() / 2 + bar.get_y(),
                f"{round(bar.get_height(), 2)}",
                ha="center",
                color="w",
                weight="bold",
                size=5,
            )

    # Setup legend
    handles, labels = plt.gca().get_legend_handles_labels()
    legend = fig.legend(handles, labels, loc="center left", bbox_to_anchor=(1, 1))
    shorten_plot_legend_strings(legend, 50)

    # Setup labels and title
    plt.setp(ax.get_xticklabels(), rotation=90)
    ax.set_ylabel(plot_metric)
    plt.suptitle(plot_title)

    plt.savefig(output, bbox_inches="tight")
    print("Created: ", output)
def main(
    json_trace: Path,
    output_directory: Path,
    depth: int,  # Fetch/Plot operations at this depth of the Json tree
    plot_metric: str,
    make_names_unique: bool,
    top_k: int,
    json_nodes_to_fold: list[str],
    step_plot_interval: int | None = None,
):
    """Load a layerwise-profile json and write prefill.png / decode_steps.png
    stacked-bar plots of `plot_metric` into `output_directory`.

    step_plot_interval: plot one of every N decode steps.  Defaults to None,
    which falls back to the module-level `args.step_plot_interval` for
    backward compatibility with the original behavior of reading the CLI
    global directly inside this function.
    """

    def prepare_data(profile_json: dict, step_keys: list[str]) -> "pd.DataFrame":
        def get_entries_and_traces(key: str):
            entries_and_traces: list[tuple[Any, Any]] = []
            for root in profile_json[key]["summary_stats"]:
                # Fold nodes in the traces as per user request. i.e. simply
                # make the requested nodes leaf-nodes.
                root = fold_nodes(root, json_nodes_to_fold)
                get_entries_at_depth(depth, entries_and_traces, root)
            return entries_and_traces

        def keep_only_top_entries(
            df: "pd.DataFrame", metric: str, top_k: int = 9
        ) -> "pd.DataFrame":
            # rename everything below the top-k cutoff to "others"
            df.loc[df.nsmallest(len(df) - top_k + 1, metric).index, ["name"]] = "others"
            return df

        def get_phase_description(key: str) -> str:
            num_running_seqs = profile_json[key]["metadata"]["num_running_seqs"]
            if num_running_seqs is not None:
                return f"{key}-seqs-{num_running_seqs}"
            else:
                return key

        # Get data for each key
        traces = list(map(lambda x: get_entries_and_traces(x), step_keys))

        # Attempt some cleanup
        if make_names_unique:
            for trace in traces:
                attempt_to_make_names_unique(trace)

        # To pandas dataframe
        trace_dfs = list(
            map(lambda t: pd.DataFrame([entry for entry, _ in t]).fillna(0), traces)
        )

        # Respect top_k
        if top_k:
            trace_dfs = list(
                map(
                    lambda trace_df: keep_only_top_entries(
                        trace_df, "cuda_time_us", top_k
                    ),
                    trace_dfs,
                )
            )

        # Fill in information about the step-keys
        for trace_df, step_key in zip(trace_dfs, step_keys):
            trace_df["phase"] = step_key
            trace_df["phase_desc"] = get_phase_description(step_key)

        # Combine all data frames so they can be put in a single plot
        traces_df = pd.concat(trace_dfs)

        # Add a derived metric `cuda_time_ms`
        traces_df["cuda_time_ms"] = traces_df["cuda_time_us"] / 1000
        traces_df = traces_df.fillna(0)

        return traces_df

    def make_plot_title_suffix(profile_json: dict) -> str:
        context = profile_json["context"]
        sparsity = context.get("sparsity", None)
        run_type = (
            f"Run {context['num_steps']} steps"
            if context["num_steps"]
            else (
                f"Complete {context['complete_num_requests_per_step']} per "
                f"step; Run till completion"
            )
        )
        return (
            f"{context['engine_args']['model']}\n"
            f"Batch={context['batch_size']}, "
            f"PromptLen={context['prompt_len']}, "
            f"NumGpus={context['engine_args']['tensor_parallel_size']}"
            f"{', Sparsity ' + sparsity if sparsity else ''}\n"
            f"Run Type: {run_type}"
        )

    profile_json = None
    with open(json_trace) as f:
        profile_json = json.load(f)
    assert profile_json is not None

    if step_plot_interval is None:
        # Backward-compat fallback: the original implementation read the
        # module-level CLI `args` global directly.
        step_plot_interval = args.step_plot_interval

    # Get all `llm.generate.step()` profile keys; "context" comes first
    step_traces = list(profile_json.keys())
    assert step_traces[0] == "context"
    step_traces = step_traces[1:]  # have only prefill and decodes
    prefills = list(filter(lambda x: "prefill" in x, step_traces))
    all_decodes = list(filter(lambda x: "decode" in x, step_traces))
    assert len(prefills) + len(all_decodes) == len(step_traces)
    assert len(prefills) == 1

    # subsample the decode steps to keep the plot readable
    decodes = all_decodes[::step_plot_interval]
    if all_decodes and decodes[-1] != all_decodes[-1]:
        # Always have the last decode
        decodes.append(all_decodes[-1])

    prefill_traces = prepare_data(profile_json, prefills)
    decode_traces = prepare_data(profile_json, decodes)

    plot_title_suffix = make_plot_title_suffix(profile_json)

    plot_trace_df(
        prefill_traces,
        plot_metric,
        "prefill " + plot_title_suffix,
        output_directory / Path("prefill.png"),
    )
    plot_trace_df(
        decode_traces,
        plot_metric,
        "decodes " + plot_title_suffix,
        output_directory / Path("decode_steps.png"),
    )
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--json-trace",
        type=str,
        required=True,
        help="json trace file output by \
            examples/offline_inference/profiling.py",
    )
    parser.add_argument(
        "--output-directory", type=str, required=False, help="Directory to output plots"
    )
    parser.add_argument(
        "--level", type=str, default="module", choices=["module", "kernel"]
    )
    parser.add_argument(
        "--top-k",
        type=int,
        default=12,
        help="Only graph the top `top_k` entries by time.",
    )
    parser.add_argument(
        "--fold-json-node",
        nargs="+",
        default=["Sampler", "LogitsProcessor"],
        help="Do not plot the children of these nodes. Let, \
                the node represent the aggregate of all its \
                children",
    )
    parser.add_argument(
        "--plot-metric",
        type=str,
        default="cuda_time_ms",
        help="Metric to plot. some options are cuda_time_ms, \
            pct_cuda_time",
    )
    parser.add_argument(
        "--step-plot-interval",
        type=int,
        default=4,
        help="For every `step_plot_interval` steps, plot 1 step",
    )

    args = parser.parse_args()

    # Prepare/Extract relevant args
    make_names_unique = False
    if args.level == "module":
        depth = -2
        make_names_unique = True
    elif args.level == "kernel":
        depth = -1
    else:
        raise Exception(f"Unexpected level value ({args.level})")

    # Normalize to Path: main() joins with the `/` operator, which raises
    # TypeError on a plain str (previously args.output_directory was passed
    # through unconverted).
    output_directory = (
        Path(args.output_directory)
        if args.output_directory
        else Path(args.json_trace).parent
    )

    if not os.path.exists(output_directory):
        os.makedirs(output_directory)

    main(
        Path(args.json_trace),
        output_directory,
        depth,
        args.plot_metric,
        make_names_unique,
        args.top_k,
        args.fold_json_node,
    )

View File

@@ -0,0 +1,325 @@
#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2018 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
# Modified version of: https://chromium.googlesource.com/chromium/tools/depot_tools.git/+/refs/heads/main/post_build_ninja_summary.py
"""Summarize the last ninja build, invoked with ninja's -C syntax.
> python3 tools/report_build_time_ninja.py -C build/..
Typical output looks like this:
```
Longest build steps for .cpp.o:
1.0 weighted s to build ...torch_bindings.cpp.o (12.4 s elapsed time)
2.0 weighted s to build ..._attn_c.dir/csrc... (23.5 s elapsed time)
2.6 weighted s to build ...torch_bindings.cpp.o (31.5 s elapsed time)
3.2 weighted s to build ...torch_bindings.cpp.o (38.5 s elapsed time)
Longest build steps for .so (linking):
0.1 weighted s to build _moe_C.abi3.so (1.0 s elapsed time)
0.5 weighted s to build ...flash_attn_c.abi3.so (1.1 s elapsed time)
6.2 weighted s to build _C.abi3.so (6.2 s elapsed time)
Longest build steps for .cu.o:
15.3 weighted s to build ...machete_mm_... (183.5 s elapsed time)
15.3 weighted s to build ...machete_mm_... (183.5 s elapsed time)
15.3 weighted s to build ...machete_mm_... (183.6 s elapsed time)
15.3 weighted s to build ...machete_mm_... (183.7 s elapsed time)
15.5 weighted s to build ...machete_mm_... (185.6 s elapsed time)
15.5 weighted s to build ...machete_mm_... (185.9 s elapsed time)
15.5 weighted s to build ...machete_mm_... (186.2 s elapsed time)
37.4 weighted s to build ...scaled_mm_c3x.cu... (449.0 s elapsed time)
43.9 weighted s to build ...scaled_mm_c2x.cu... (527.4 s elapsed time)
344.8 weighted s to build ...attention_...cu.o (1087.2 s elapsed time)
1110.0 s weighted time (10120.4 s elapsed time sum, 9.1x parallelism)
134 build steps completed, average of 0.12/s
```
"""
import argparse
import errno
import fnmatch
import os
import sys
from collections import defaultdict
# The number of long build times to report:
long_count = 10
# The number of long times by extension to report
long_ext_count = 10
class Target:
    """A single build step parsed out of a .ninja_log file."""

    def __init__(self, start, end):
        """Record the step's start/end timestamps (float seconds)."""
        self.start = start
        self.end = end
        # Output names produced by this step; filled in by the log reader.
        self.targets = []
        self.weighted_duration = 0.0

    def Duration(self):
        """Elapsed wall time of this step, in seconds, as a float."""
        return self.end - self.start

    def SetWeightedDuration(self, weighted_duration):
        """Store the weighted duration, in seconds, as a float."""
        self.weighted_duration = weighted_duration

    def WeightedDuration(self):
        """Return the step's weighted duration in seconds as a float.

        The weighted duration divides elapsed time by how many other steps
        were running concurrently, approximating this step's impact on the
        total build time: serialized/serializing steps end up with longer
        weighted durations. It should never exceed the raw duration.
        """
        # Tolerate tiny floating-point drift when sanity-checking.
        fudge = 0.000002
        if self.weighted_duration > self.Duration() + fudge:
            print("{} > {}?".format(self.weighted_duration, self.Duration()))
        assert self.weighted_duration <= self.Duration() + fudge
        return self.weighted_duration

    def DescribeTargets(self):
        """Return a short printable summary of this step's outputs.

        Steps with dozens of outputs are truncated; 65 characters fits most
        long single-target names while minimizing word wrapping.
        """
        summary = ", ".join(self.targets)
        limit = 65
        return summary if len(summary) <= limit else summary[:limit] + "..."
# Copied with some modifications from ninjatracing
def ReadTargets(log, show_all):
    """Parse an open .ninja_log file into Target objects.

    Keeps only the most recent build's steps unless |show_all| is true.
    Returns a list of Target objects (not sorted).
    """
    header = log.readline()
    assert header == "# ninja log v5\n", "unrecognized ninja log version {!r}".format(
        header
    )
    by_hash = {}
    newest_end_seen = 0.0
    for raw_line in log:
        fields = raw_line.strip().split("\t")
        if len(fields) != 5:
            # A rudely-halted ninja can leave a corrupt .ninja_log; silently
            # skip malformed rows.
            continue
        start_ms, end_ms, _, name, cmdhash = fields  # Ignore restart.
        # Timestamps are integral milliseconds; convert to float seconds.
        begin = int(start_ms) / 1000.0
        finish = int(end_ms) / 1000.0
        if not show_all and finish < newest_end_seen:
            # Records are appended as commands complete, so end times are
            # monotonic within one build. A step finishing earlier than one
            # already seen marks the start of a new (possibly incremental)
            # build: drop the previous data so the new build is reported
            # independently. Start times give no such guarantee.
            by_hash = {}
        step = None
        if cmdhash in by_hash:
            step = by_hash[cmdhash]
            if not show_all and (step.start != begin or step.end != finish):
                # Several back-to-back one-or-two-step builds may never move
                # end times backwards, so the new build would go undetected.
                # A repeated command hash whose start/stop shifted is the
                # tell: reset the dictionary.
                by_hash = {}
                step = None
        if not step:
            by_hash[cmdhash] = step = Target(begin, finish)
        newest_end_seen = finish
        step.targets.append(name)
    return list(by_hash.values())
def GetExtension(target, extra_patterns):
    """Return the file extension that best represents a target.

    For targets that generate multiple outputs it is important to return a
    consistent 'canonical' extension. Ultimately the goal is to group build
    steps by type.

    Args:
      target: object with a `.targets` list of output file names.
      extra_patterns: optional semicolon-separated fnmatch patterns; the
        first pattern matching an output is returned as the group name.

    Returns:
      A grouping string such as ".cu.o", ".so (linking)" or "mojo".
    """
    # Fix: default up front so a target with no outputs (e.g. from a corrupt
    # log entry) no longer raises UnboundLocalError at the return below.
    extension = "(no extension found)"
    for output in target.targets:
        if extra_patterns:
            for fn_pattern in extra_patterns.split(";"):
                if fnmatch.fnmatch(output, "*" + fn_pattern + "*"):
                    return fn_pattern
        # Not a true extension, but a good grouping.
        if output.endswith("type_mappings"):
            extension = "type_mappings"
            break
        # Capture two extensions if present. For example: file.javac.jar
        # should be distinguished from file.interface.jar.
        root, ext1 = os.path.splitext(output)
        _, ext2 = os.path.splitext(root)
        extension = ext2 + ext1  # Preserve the order in the file name.
        if len(extension) == 0:
            extension = "(no extension found)"
        if ext1 in [".pdb", ".dll", ".exe"]:
            extension = "PEFile (linking)"
            # Make sure that .dll and .exe are grouped together and that the
            # .dll.lib files don't cause these to be listed as libraries.
            break
        if ext1 in [".so", ".TOC"]:
            extension = ".so (linking)"
            # Attempt to identify linking, avoid identifying as '.TOC'.
            break
        # Make sure .obj files don't get categorized as mojo files.
        if ext1 in [".obj", ".o"]:
            break
        # Jars are the canonical output of java targets.
        if ext1 == ".jar":
            break
        # Normalize all mojo related outputs to 'mojo'.
        if output.count(".mojom") > 0:
            extension = "mojo"
            break
    return extension
def SummarizeEntries(entries, extra_step_types):
    """Print a summary of the passed in list of Target objects.

    Reports, per extension group, the longest build steps by weighted
    duration, then overall parallelism statistics.

    Args:
      entries: list of Target objects (see ReadTargets).
      extra_step_types: optional semicolon-separated fnmatch patterns,
        forwarded to GetExtension for extra grouping.
    """
    # Fix: bail out on an empty log instead of crashing on the timeline
    # indexing below.
    if not entries:
        print("No build steps found.")
        return
    # Create a list that is in order by time stamp and has entries for the
    # beginning and ending of each build step (one time stamp may have multiple
    # entries due to multiple steps starting/stopping at exactly the same time).
    # Iterate through this list, keeping track of which tasks are running at all
    # times. At each time step calculate a running total for weighted time so
    # that when each task ends its own weighted time can easily be calculated.
    task_start_stop_times = []
    earliest = -1
    latest = 0
    total_cpu_time = 0
    for target in entries:
        if earliest < 0 or target.start < earliest:
            earliest = target.start
        if target.end > latest:
            latest = target.end
        total_cpu_time += target.Duration()
        task_start_stop_times.append((target.start, "start", target))
        task_start_stop_times.append((target.end, "stop", target))
    length = latest - earliest
    weighted_total = 0.0
    # Sort by the time/type records and ignore |target|
    task_start_stop_times.sort(key=lambda times: times[:2])
    # Now we have all task start/stop times sorted by when they happen. If a
    # task starts and stops on the same time stamp then the start will come
    # first because of the alphabet, which is important for making this work
    # correctly.
    # Track the tasks which are currently running.
    running_tasks = {}
    # Record the time we have processed up to so we know how to calculate time
    # deltas.
    last_time = task_start_stop_times[0][0]
    # Track the accumulated weighted time so that it can efficiently be added
    # to individual tasks.
    last_weighted_time = 0.0
    # Scan all start/stop events.
    for event in task_start_stop_times:
        time, action_name, target = event
        # Accumulate weighted time up to now.
        num_running = len(running_tasks)
        if num_running > 0:
            # Update the total weighted time up to this moment.
            last_weighted_time += (time - last_time) / float(num_running)
        if action_name == "start":
            # Record the total weighted task time when this task starts.
            running_tasks[target] = last_weighted_time
        if action_name == "stop":
            # Record the change in the total weighted task time while this task
            # ran.
            weighted_duration = last_weighted_time - running_tasks[target]
            target.SetWeightedDuration(weighted_duration)
            weighted_total += weighted_duration
            del running_tasks[target]
        last_time = time
    assert len(running_tasks) == 0
    # Warn if the sum of weighted times is off by more than half a second.
    # Fix: all times here are float seconds (the log reader divides the ms
    # values by 1000), so the old threshold of 500 only matched the comment
    # if times were milliseconds; use 0.5.
    if abs(length - weighted_total) > 0.5:
        print(
            "Warning: Possible corrupt ninja log, results may be "
            "untrustworthy. Length = {:.3f}, weighted total = {:.3f}".format(
                length, weighted_total
            )
        )
    # Group the steps by canonical extension and report the slowest steps of
    # each group (bounded by the module-level `long_count`).
    entries_by_ext = defaultdict(list)
    for target in entries:
        extension = GetExtension(target, extra_step_types)
        entries_by_ext[extension].append(target)
    for key, values in entries_by_ext.items():
        print(" Longest build steps for {}:".format(key))
        values.sort(key=lambda x: x.WeightedDuration())
        for target in values[-long_count:]:
            print(
                " {:8.1f} weighted s to build {} ({:.1f} s elapsed time)".format(
                    target.WeightedDuration(),
                    target.DescribeTargets(),
                    target.Duration(),
                )
            )
    # Fix: guard against a zero-length build (every record sharing a single
    # timestamp), which previously divided by zero.
    if length > 0:
        print(
            " {:.1f} s weighted time ({:.1f} s elapsed time sum, {:1.1f}x "
            "parallelism)".format(length, total_cpu_time, total_cpu_time * 1.0 / length)
        )
        print(
            " {} build steps completed, average of {:1.2f}/s".format(
                len(entries), len(entries) / (length)
            )
        )
def main():
    """Parse CLI options, locate the ninja log and print its summary.

    Returns errno.ENOENT when the log file cannot be opened; otherwise
    returns None (process exit status 0).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-C", dest="build_directory", help="Build directory.")
    parser.add_argument(
        "-s",
        "--step-types",
        help="semicolon separated fnmatch patterns for build-step grouping",
    )
    parser.add_argument("--log-file", help="specific ninja log file to analyze.")
    args, _extra_args = parser.parse_known_args()

    # Resolve the log path: an explicit --log-file wins, then
    # <build_directory>/.ninja_log, then ./.ninja_log.
    log_file = ".ninja_log"
    if args.build_directory:
        log_file = os.path.join(args.build_directory, log_file)
    if args.log_file:
        log_file = args.log_file

    if args.step_types:
        # Make room for the extra build types.
        global long_ext_count
        long_ext_count += len(args.step_types.split(";"))

    try:
        with open(log_file) as log:
            SummarizeEntries(ReadTargets(log, False), args.step_types)
    except OSError:
        print("Log file {!r} not found, no build summary created.".format(log_file))
        return errno.ENOENT
# Propagate main()'s return value as the process exit status: None maps to
# 0 (success), errno.ENOENT signals the missing-log failure.
if __name__ == "__main__":
    sys.exit(main())

95
tools/vllm-tpu/build.sh Executable file
View File

@@ -0,0 +1,95 @@
#!/bin/bash
set -e # Exit immediately if a command exits with a non-zero status.

# Script to build VLLM wheel for TPU with an optional version override.
#
# Usage: build.sh [VERSION]
#   VERSION (optional) is exported as VLLM_VERSION_OVERRIDE for the build.
#
# Strategy: temporarily rename the distribution from "vllm" to "vllm-tpu"
# (in pyproject.toml and in the files below that query package metadata by
# name), build the wheel into dist/, then restore the original sources via
# an EXIT/signal trap.

SCRIPT_PATH_PARAM="$0"
TOOLS_DIR=$(cd "$(dirname "$SCRIPT_PATH_PARAM")" && pwd) # Absolute path to the script's directory
REPO_ROOT=$(cd "$TOOLS_DIR/../../" && pwd) # Absolute path to the repo root
VLLM_DIR="$REPO_ROOT/" # Path to the vllm sources

# Files that look up the installed distribution by the name "vllm" and must
# therefore be patched to query "vllm-tpu" instead.
CHANGE_FILE_LIST=(
    "vllm/entrypoints/cli/main.py"
    "vllm/entrypoints/cli/run_batch.py"
    "vllm/utils/__init__.py"
    "vllm/platforms/__init__.py"
)

# Ensure we are not running from within the vllm directory if SCRIPT_PATH_PARAM is relative like "."
if [ "$TOOLS_DIR" = "$VLLM_DIR" ]; then
    echo "Error: This script should not be run from the vllm directory directly if using relative paths."
    echo "Place it in a subdirectory like 'tools/vllm-tpu' and run it from the repository root or via its full path."
    exit 1
fi

# Optional version argument
if [ -n "$1" ]; then
    USER_VERSION="$1"
    export VLLM_VERSION_OVERRIDE="$USER_VERSION"
    echo "User defined version: $USER_VERSION"
else
    echo "No version override supplied. Using default version from source."
fi

PYPROJECT_FILE="$VLLM_DIR/pyproject.toml"

# Backup and update the project name.
# NOTE(review): if pyproject.toml already says vllm-tpu (e.g. a previous run
# was killed before its trap fired), PATCHED stays false and cleanup will
# not restore anything — confirm that leftover state is resolved manually.
if ! grep -q "name = \"vllm-tpu\"" "$PYPROJECT_FILE"; then
    echo "Patching pyproject.toml project name to vllm-tpu..."
    cp "$PYPROJECT_FILE" "${PYPROJECT_FILE}.bak"
    # Only the first 'name = "vllm"' occurrence is renamed (0,/re/ range).
    sed -i '0,/^name = "vllm"/s//name = "vllm-tpu"/' "$PYPROJECT_FILE"
    echo "Patching ${CHANGE_FILE_LIST[@]} vllm to vllm-tpu..."
    # patching
    # importlib.metadata.version('vllm') -> importlib.metadata.version('vllm-tpu')
    # importlib.metadata.version("vllm") -> importlib.metadata.version("vllm-tpu")
    # importlib.metadata.metadata('vllm') -> importlib.metadata.metadata('vllm-tpu')
    # importlib.metadata.metadata("vllm") -> importlib.metadata.metadata("vllm-tpu")
    # version('vllm') -> version('vllm-tpu')
    # version("vllm") -> version("vllm-tpu")
    # The \(['\"]\)...\1 backreference matches the same quote style on both
    # sides of the package name.
    sed -i \
        -e "s/importlib.metadata.version(\(['\"]\)vllm\1)/importlib.metadata.version(\1vllm-tpu\1)/" \
        -e "s/importlib.metadata.metadata(\(['\"]\)vllm\1)/importlib.metadata.metadata(\1vllm-tpu\1)/" \
        -e "s/version(\(['\"]\)vllm\1)/version(\1vllm-tpu\1)/" \
        "${CHANGE_FILE_LIST[@]}"
    PATCHED=true
else
    PATCHED=false
fi

# Navigate to the vllm directory
cd "$VLLM_DIR"

# Cleanup function to be called on exit or error
cleanup() {
    echo "Cleaning up..."
    if [ "$PATCHED" = true ]; then
        echo "Restoring original pyproject.toml..."
        cp "${PYPROJECT_FILE}.bak" "$PYPROJECT_FILE"
        rm -f "${PYPROJECT_FILE}.bak"
        echo "Restoring vllm code..."
        # Inverse of the patching seds above: vllm-tpu -> vllm.
        sed -i \
            -e "s/importlib.metadata.version(\(['\"]\)vllm-tpu\1)/importlib.metadata.version(\1vllm\1)/" \
            -e "s/importlib.metadata.metadata(\(['\"]\)vllm-tpu\1)/importlib.metadata.metadata(\1vllm\1)/" \
            -e "s/version(\(['\"]\)vllm-tpu\1)/version(\1vllm\1)/" \
            "${CHANGE_FILE_LIST[@]}"
    fi
}
trap cleanup EXIT HUP INT QUIT PIPE TERM # Register cleanup function to run on script exit and various signals

echo "Updating pyproject.toml completed. Proceeding with build..."
echo "Building wheel for TPU..."
rm -rf dist/
mkdir -p dist/

# User confirmed to use 'python -m build' directly
if ! VLLM_TARGET_DEVICE=tpu python -m build; then
    echo "Error: Python build command failed. Check if 'python -m build' works and the 'build' module is installed."
    exit 1
fi

# Build succeeded: drop the trap and run cleanup exactly once so the EXIT
# handler does not fire a second time.
trap - EXIT HUP INT QUIT PIPE TERM
cleanup
exit 0