Add minimal vLLM 0.16.1 build repo for BI-V150

This commit is contained in:
2026-04-18 10:56:22 +08:00
commit d69657327e
1895 changed files with 615301 additions and 0 deletions

View File

1131
vllm/compilation/backends.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,57 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from typing import Any, Protocol
from vllm.config import CUDAGraphMode, VllmConfig
class AbstractStaticGraphWrapper(Protocol):
    """
    Protocol that platform-specific static-graph wrappers implement: it
    takes a callable and makes it capturable as a static graph.
    """

    def __init__(
        self,
        runnable: Callable[..., Any],
        vllm_config: VllmConfig,
        runtime_mode: CUDAGraphMode,
        **kwargs: Any,
    ) -> None:
        """
        Configure the wrapper for graph capture and graph execution.

        Args:
            runnable (Callable): Callable that will be wrapped and captured.
            vllm_config (VllmConfig): The global vLLM configuration.
            runtime_mode (CUDAGraphMode): Style of the static graph
                runtime; see CUDAGraphMode in vllm/config.py. Only the
                `NONE`, `PIECEWISE` and `FULL` members are used as
                concrete runtime modes for cudagraph dispatching.

        Keyword Args:
            kwargs: Extra keyword arguments carrying platform-specific
                configuration.
        """
        raise NotImplementedError

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """
        Run the wrapped callable.

        When the runtime mode found in the current ForwardContext equals
        the runtime mode of this instance, the CUDAGraph is replayed (or
        first captured from the callable if no capture exists yet). In
        every other case the original callable is invoked directly.

        Args:
            *args: Positional arguments forwarded to the callable.
            **kwargs: Keyword arguments forwarded to the callable.

        Returns:
            Any: Whatever the executed callable produced.
        """
        raise NotImplementedError

516
vllm/compilation/caching.py Normal file
View File

@@ -0,0 +1,516 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import hashlib
import inspect
import os
import pickle
from collections.abc import Callable, Sequence
from typing import Any, Literal
from unittest.mock import patch
import torch
from torch.utils import _pytree as pytree
import vllm.envs as envs
from vllm.compilation.compiler_interface import get_inductor_factors
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.config.utils import hash_factors
from vllm.logger import init_logger
from vllm.utils.hashing import safe_hash
# The import below can fail with ImportError; in that case fall back to
# `object` so that `SerializableCallable` can still be used as a base class.
try:
    from torch._dynamo.aot_compile import SerializableCallable
except ImportError:
    SerializableCallable = object
# Whichever branch ran, the resulting name must be a class.
assert isinstance(SerializableCallable, type)
logger = init_logger(__name__)
class StandaloneCompiledArtifacts:
"""Storage for standalone compiled artifacts with content-based deduplication.
Deduplication works via a two-level indirection:
1. `submodule_bytes` maps "{submod_name}_{shape}" -> SHA256 hash
2. `submodule_bytes_store` maps SHA256 hash -> actual bytes
When inserting, we compute the SHA256 hash of the bytes. If the hash
already exists in `submodule_bytes_store`, we reuse the existing entry
rather than storing duplicate bytes. This is common because submodules
often compile to identical artifacts (e.g., identical transformer layers
split on attn)
"""
def __init__(self) -> None:
# dict from submodule name to byte hash
self.submodule_bytes: dict[str, str] = {}
# dict from byte hash to bytes
self.submodule_bytes_store: dict[str, bytes] = {}
# dict from byte hash to loaded module
self.loaded_submodule_store: dict[str, Any] = {}
def insert(self, submod_name: str, shape: str, entry: bytes) -> None:
hasher = hashlib.sha256()
hasher.update(entry)
hex_digest = hasher.hexdigest()
self.submodule_bytes[f"{submod_name}_{shape}"] = hex_digest
if hex_digest not in self.submodule_bytes_store:
self.submodule_bytes_store[hex_digest] = entry
logger.debug(
"inserting new artifact for submod %s with shape %s "
"(%s bytes) at hash %s",
submod_name,
shape,
len(entry),
hex_digest,
)
else:
logger.debug(
"reusing existing cache artifact for submod %s "
"with shape %s (%s bytes) at hash %s",
submod_name,
shape,
len(entry),
hex_digest,
)
def get(self, submod_name: str, shape: str) -> bytes:
logger.debug(
"getting artifact for submod %s with shape %s",
submod_name,
shape,
)
return self.submodule_bytes_store[
self.submodule_bytes[f"{submod_name}_{shape}"]
]
def get_loaded(self, submod_name: str, shape: str) -> Any:
logger.debug(
"getting artifact for submod %s with shape %s",
submod_name,
shape,
)
return self.loaded_submodule_store[
self.submodule_bytes[f"{submod_name}_{shape}"]
]
def size_bytes(self) -> int:
return sum(len(entry) for entry in self.submodule_bytes_store.values())
def num_artifacts(self) -> int:
return len(self.submodule_bytes_store)
def num_entries(self) -> int:
return len(self.submodule_bytes)
def submodule_names(self) -> list[str]:
# get unique "{submod_name}" from "{submod_name}_{shape}", preserving order
names = [cache_key.rsplit("_", 1)[0] for cache_key in self.submodule_bytes]
return list(dict.fromkeys(names))
def load_all(self) -> None:
import concurrent.futures
# check already loaded
if len(self.loaded_submodule_store) == len(self.submodule_bytes_store):
return
from torch._inductor.standalone_compile import AOTCompiledArtifact
def _load_entry(entry_bytes: bytes) -> AOTCompiledArtifact:
entry = pickle.loads(entry_bytes)
return AOTCompiledArtifact.deserialize(entry)
with concurrent.futures.ThreadPoolExecutor() as executor:
entries = list(self.submodule_bytes_store.values())
loaded_entries = list(executor.map(_load_entry, entries))
for i, k in enumerate(self.submodule_bytes_store.keys()):
self.loaded_submodule_store[k] = loaded_entries[i]
logger.debug("loaded all %s submodules", self.num_artifacts())
def __getstate__(self) -> dict[str, dict[str, str] | dict[str, bytes]]:
return {
"submodule_bytes": self.submodule_bytes,
"submodule_bytes_store": self.submodule_bytes_store,
}
def __setstate__(self, state: dict[str, dict[str, Any]]) -> None:
self.submodule_bytes = state["submodule_bytes"]
self.submodule_bytes_store = state["submodule_bytes_store"]
self.loaded_submodule_store = {}
class VllmSerializableFunction(SerializableCallable):  # type: ignore[misc]
    """
    A wrapper around a function compiled by vLLM. Calling the wrapper
    forwards all arguments to `optimized_call` and returns its result.

    It also implements a serialization interface to support PyTorch's precompile
    with custom backend, so that we can save and load the compiled function on
    disk. There's no need to wrap around the compiled function if we don't want
    to serialize them in particular cases.

    Right now serialization for the custom backend is done via
    serializing the Dynamo fx graph plus example inputs.
    """

    def __init__(
        self,
        graph_module: torch.fx.GraphModule,
        example_inputs: Sequence[Any],
        prefix: str,
        optimized_call: Callable[..., Any],
        is_encoder: bool = False,
        vllm_backend: Any | None = None,
        sym_tensor_indices: list[int] | None = None,
    ) -> None:
        """
        Store the graph, its example inputs and the callable to dispatch to.

        `shape_env` is taken from the first `torch.SymInt` found in
        `example_inputs`; it stays None when there is no such input.
        """
        assert isinstance(graph_module, torch.fx.GraphModule)
        self.graph_module = graph_module
        self.example_inputs = example_inputs
        self.prefix = prefix
        self.optimized_call = optimized_call
        self.is_encoder = is_encoder
        self.shape_env = None
        self.vllm_backend = vllm_backend
        self.sym_tensor_indices = sym_tensor_indices
        sym_input = next(
            (i for i in self.example_inputs if isinstance(i, torch.SymInt)), None
        )
        if sym_input is not None:
            self.shape_env = sym_input.node.shape_env

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Forward the call to `optimized_call` unchanged."""
        return self.optimized_call(*args, **kwargs)

    @classmethod
    def serialize_compile_artifacts(
        cls, compiled_fn: "VllmSerializableFunction"
    ) -> bytes:
        """
        Pickle the state of `compiled_fn` into bytes.

        The graph module and the example inputs are dumped with
        GraphPickler; when `compiled_fn.vllm_backend` is set, the standalone
        compile artifacts it collects are stored alongside them.
        """
        import sympy
        from torch._subclasses import FakeTensorMode
        from torch.fx._graph_pickler import GraphPickler, Options

        # Shallow copy: the dict is new, but the graph module (and its nodes)
        # are still the objects owned by `compiled_fn`.
        state = compiled_fn.__dict__.copy()
        # These entries are not part of the serialized payload.
        state.pop("optimized_call")
        state.pop("shape_env")
        state.pop("vllm_backend", None)
        # Drop the stack metadata from every node of the top-level graph and
        # of each direct child that has a graph. Because of the shallow copy
        # above this mutates the node metadata of `compiled_fn`'s own graph.
        for node in state["graph_module"].graph.nodes:
            node.meta.pop("source_fn_stack", None)
            node.meta.pop("nn_module_stack", None)
        for name, submod in state["graph_module"].named_children():
            if hasattr(submod, "graph"):
                for node in submod.graph.nodes:
                    node.meta.pop("source_fn_stack", None)
                    node.meta.pop("nn_module_stack", None)

        # Keep a handle on the stock reducer so the override can delegate.
        graph_reducer_override = GraphPickler.reducer_override

        def _graph_reducer_override(
            self: GraphPickler, obj: Any
        ) -> tuple[Callable[..., Any], tuple[Any, ...]] | Any:
            # sympy.Function subclasses that expose `_torch_unpickler` are
            # reduced to that unpickler plus their handler name.
            if (
                inspect.isclass(obj)
                and issubclass(obj, sympy.Function)
                and hasattr(obj, "_torch_unpickler")
            ):
                return obj._torch_unpickler, (obj._torch_handler_name,)
            # A FakeTensorMode is reduced to `type(None)()`, i.e. None.
            if isinstance(obj, FakeTensorMode):
                return type(None), ()
            return graph_reducer_override(self, obj)

        # NOTE(review): both branches below currently perform the same
        # mapping (tensor -> empty meta tensor), while the fallback path in
        # deserialize_compile_artifacts substitutes runtime inputs only for
        # example inputs that are None. Confirm whether the `else` branch is
        # meant to differ.
        if state.get("sym_tensor_indices"):
            # put tensor inputs on meta device since their data
            # isn't needed, yet we need the meta for make_copy_and_call
            state["example_inputs"] = pytree.tree_map_only(
                torch.Tensor,
                lambda inp: torch.empty_like(inp, device="meta"),
                state["example_inputs"],
            )
        else:
            # mask off all tensor inputs since they are large and not needed.
            state["example_inputs"] = pytree.tree_map_only(
                torch.Tensor,
                lambda inp: torch.empty_like(inp, device="meta"),
                state["example_inputs"],
            )
        # The reducer override is only active while dumping.
        with patch.object(GraphPickler, "reducer_override", _graph_reducer_override):
            state["graph_module"] = GraphPickler.dumps(
                state["graph_module"], Options(ops_filter=None)
            )
            state["example_inputs"] = GraphPickler.dumps(state["example_inputs"])

        if compiled_fn.vllm_backend:
            (
                standalone_compile_artifacts,
                sym_shape_indices_map,
                returns_tuple_map,
            ) = compiled_fn.vllm_backend.collect_standalone_compile_artifacts()
            state["standalone_compile_artifacts"] = standalone_compile_artifacts
            state["sym_shape_indices_map"] = sym_shape_indices_map
            state["returns_tuple_map"] = returns_tuple_map
        return pickle.dumps(state)

    @classmethod
    def deserialize_compile_artifacts(cls, data: bytes) -> "VllmSerializableFunction":
        """
        Rebuild a VllmSerializableFunction from bytes produced by
        `serialize_compile_artifacts`.

        With VLLM_USE_MEGA_AOT_ARTIFACT set, the function is reconstructed
        from the stored standalone compile artifacts. Otherwise a VllmBackend
        is created and invoked lazily on the first call.
        """
        from torch._guards import TracingContext, tracing
        from torch._subclasses import FakeTensorMode
        from torch.fx._graph_pickler import GraphPickler
        from torch.fx.experimental.symbolic_shapes import ShapeEnv

        # NOTE(security): pickle.loads runs arbitrary code from `data`; the
        # bytes must come from a trusted cache.
        state = pickle.loads(data)
        fake_mode = FakeTensorMode(shape_env=ShapeEnv())
        state["graph_module"] = GraphPickler.loads(state["graph_module"], fake_mode)
        state["graph_module"].recompile()
        state["example_inputs"] = GraphPickler.loads(state["example_inputs"], fake_mode)

        # Remove the artifact-related entries so that the remaining `state`
        # can be passed as keyword arguments to the constructor.
        standalone_compile_artifacts = state.pop("standalone_compile_artifacts", None)
        sym_shape_indices_map = state.pop("sym_shape_indices_map", {})
        returns_tuple_map = state.pop("returns_tuple_map", {})

        if envs.VLLM_USE_MEGA_AOT_ARTIFACT:
            assert standalone_compile_artifacts is not None
            submod_names = standalone_compile_artifacts.submodule_names()
            num_submods = len(submod_names)
            num_artifacts = standalone_compile_artifacts.num_artifacts()

            logger.info(
                "reconstructing serializable fn from standalone compile "
                "artifacts. num_artifacts=%d num_submods=%d",
                num_artifacts,
                num_submods,
            )

            fn = reconstruct_serializable_fn_from_mega_artifact(
                state=state,
                standalone_compile_artifacts=standalone_compile_artifacts,
                vllm_config=get_current_vllm_config(),
                sym_shape_indices_map=sym_shape_indices_map,
                returns_tuple_map=returns_tuple_map,
            )

            logger.info(
                "reconstructed serializable fn from standalone compile artifacts"
            )

            return fn

        # Fall back to standard VllmBackend
        from vllm.compilation.backends import VllmBackend

        is_encoder = state.get("is_encoder", False)
        vllm_backend: VllmBackend = VllmBackend(
            get_current_vllm_config(), state["prefix"], is_encoder
        )

        def optimized_call(*example_inputs: Any) -> Any:
            """
            On the first run of the optimized call, we rerun the compiler
            backend which should result in a cache hit. After the backend
            call returns, we just do a one-time replacement of the optimized
            call with the compiled function, so that subsequent calls are on
            the AOT compiled path.
            """
            # Stored example inputs that are None are replaced by the
            # runtime argument at the same position.
            compile_inputs = [
                inp if inp is not None else example_inputs[i]
                for i, inp in enumerate(fn.example_inputs)
            ]
            with tracing(TracingContext(fake_mode)):
                fn.optimized_call = vllm_backend(
                    state["graph_module"], compile_inputs
                ).optimized_call
            return fn.optimized_call(*example_inputs)

        # `fn` is referenced by the closure above, which only runs after this
        # assignment has happened.
        fn = cls(**state, optimized_call=optimized_call)
        return fn

    @property
    def co_name(self) -> Literal["VllmSerializableFunction"]:
        """
        Used for depyf debugging.
        """
        return "VllmSerializableFunction"
def reconstruct_serializable_fn_from_mega_artifact(
    state: dict[str, Any],
    standalone_compile_artifacts: "StandaloneCompiledArtifacts",
    vllm_config: VllmConfig,
    sym_shape_indices_map: dict[str, list[int]],
    returns_tuple_map: dict[str, bool],
) -> "VllmSerializableFunction":
    """Rebuild a VllmSerializableFunction out of cached inductor artifacts.

    No compilation is re-run. The steps are:
    1. Load every cached artifact.
    2. Group the loaded callables per submodule, keyed by shape.
    3. Create one PiecewiseBackend per submodule that dispatches to them.
    4. Wrap each backend with cudagraph if needed and install it on the
       split graph module.
    5. Return the resulting VllmSerializableFunction.

    Note: the logic here mirrors PiecewiseCompileInterpreter in backends.py.
    Both build PiecewiseBackend instances and wrap them with cudagraph; they
    differ in what the backend is given:
    - here: pre-compiled runnables (compiled_runnables set, graph is None)
    - PiecewiseCompileInterpreter: the FX graph to compile (graph set,
      compiled_runnables is None)
    When changing the backend creation/wrapping logic, consider updating
    both places.

    Args:
        state: Deserialized state dict containing graph_module,
            example_inputs, prefix, sym_tensor_indices, is_encoder, etc.
        standalone_compile_artifacts: Holder of the pre-compiled artifacts
            for every submodule/shape combination.
        vllm_config: The vLLM configuration.
        sym_shape_indices_map: Maps submod_name to its sym_shape_indices.
        returns_tuple_map: Maps submod_name to its returns_tuple flag.

    Returns:
        A VllmSerializableFunction that can be called directly.
    """
    from vllm.compilation.backends import (
        VllmBackend,
        make_copy_and_call,
        wrap_with_cudagraph_if_needed,
    )
    from vllm.compilation.piecewise_backend import PiecewiseBackend

    prefix = state["prefix"]
    is_encoder = state.get("is_encoder", False)
    split_gm = state["graph_module"]
    compilation_config = vllm_config.compilation_config

    standalone_compile_artifacts.load_all()
    submod_names = standalone_compile_artifacts.submodule_names()

    # submodule name -> {shape string -> loaded artifact}
    runnables_by_submod: dict[str, dict[str, Callable[..., Any]]] = {}
    for cache_key in standalone_compile_artifacts.submodule_bytes:
        name, shape_str = cache_key.rsplit("_", 1)
        loaded = standalone_compile_artifacts.get_loaded(name, shape_str)
        runnables_by_submod.setdefault(name, {})[shape_str] = loaded

    vllm_backend = VllmBackend(vllm_config, prefix, is_encoder)
    dummy_cache_dir = os.path.join(envs.VLLM_CACHE_ROOT, "dummy_cache")
    os.makedirs(dummy_cache_dir, exist_ok=True)
    vllm_backend.compiler_manager.initialize_cache(
        cache_dir=dummy_cache_dir,
        disable_cache=True,
        prefix=prefix,
    )

    # spot check that cached submodules exist in the graph structure
    graph_children = {child_name for child_name, _ in split_gm.named_children()}
    missing = set(submod_names) - graph_children
    assert not missing, (
        f"artifacts reference submodules not in graph: {missing}. "
        f"graph has: {sorted(graph_children)}"
    )

    total = len(submod_names)
    for index, submod_name in enumerate(submod_names):
        assert submod_name in sym_shape_indices_map
        assert submod_name in returns_tuple_map

        backend = PiecewiseBackend(
            graph=None,  # not needed for cached artifacts
            vllm_config=vllm_config,
            piecewise_compile_index=index,
            total_piecewise_compiles=total,
            sym_shape_indices=sym_shape_indices_map[submod_name],
            vllm_backend=vllm_backend,
            returns_tuple=returns_tuple_map[submod_name],
            compiled_runnables=runnables_by_submod[submod_name],
        )

        # The last two arguments flag the first / last submodule.
        split_gm.__dict__[submod_name] = wrap_with_cudagraph_if_needed(
            backend,
            vllm_config,
            compilation_config,
            index == 0,
            index == total - 1,
        )
        logger.debug(
            "Replaced submodule %s with piecewise backend from cache",
            submod_name,
        )

    if not compilation_config.cudagraph_copy_inputs:
        optimized_call = split_gm
    else:
        sym_tensor_indices = state["sym_tensor_indices"]
        example_inputs = state["example_inputs"]
        input_buffers = [
            torch.empty_like(
                example_inputs[idx], device=vllm_config.device_config.device
            )
            for idx in sym_tensor_indices
        ]
        optimized_call = make_copy_and_call(sym_tensor_indices, input_buffers, split_gm)

    return VllmSerializableFunction(
        **state,
        optimized_call=optimized_call,
        vllm_backend=None,
    )
def aot_compile_hash_factors(vllm_config: VllmConfig) -> list[str]:
    """Collect the factors that feed the AOT compile cache hash.

    The list holds, in order: the hash of the env compile factors, the hash
    of `vllm_config`, and (only when VLLM_USE_MEGA_AOT_ARTIFACT is set) the
    inductor factors.
    """
    factors = [
        # 0. factors come from the env, for example, The values of
        # VLLM_PP_LAYER_PARTITION will affect the computation graph.
        hash_factors(envs.compile_factors()),
        # 1. factors come from the vllm_config (it mainly summarizes how the
        # model is created)
        vllm_config.compute_hash(),
    ]

    # 2. inductor factors if applicable
    if envs.VLLM_USE_MEGA_AOT_ARTIFACT:
        factors += get_inductor_factors()

    return factors
def _compute_code_hash_with_content(file_contents: dict[str, str]) -> str:
    """Hash a mapping of file path -> file content.

    Paths are visited in sorted order. Every path is part of the hashed
    text; the content is added as well, except for the "<string>" path.
    """
    pieces = []
    for path in sorted(file_contents):
        pieces.append(path)
        # "<string>" means the function was dynamically generated, with
        # e.g. exec(). We can't actually check these.
        if path != "<string>":
            pieces.append(file_contents[path])
    digest = safe_hash("\n".join(pieces).encode(), usedforsecurity=False)
    result: str = digest.hexdigest()
    return result
def _compute_code_hash(files: set[str]) -> str:
    """Hash the contents of the traced source files.

    Args:
        files: Paths of the traced files. Entries that are not regular
            files (e.g., <string>, <frozen modules>, etc.) contribute an
            empty content.

    Returns:
        The hex digest computed by `_compute_code_hash_with_content`.
    """
    # Sort only for a stable, readable log line; the hash itself is made
    # order-independent by `_compute_code_hash_with_content`.
    logger.debug(
        "Traced files (to be considered for compilation cache):\n%s",
        "\n".join(sorted(files)),
    )
    file_contents = {}
    for filepath in files:
        # Skip files that don't exist (e.g., <string>, <frozen modules>, etc.)
        if not os.path.isfile(filepath):
            file_contents[filepath] = ""
        else:
            # Decode explicitly as UTF-8 so the result does not depend on
            # the locale's preferred encoding.
            with open(filepath, encoding="utf-8") as f:
                file_contents[filepath] = f.read()
    return _compute_code_hash_with_content(file_contents)

View File

@@ -0,0 +1,660 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import copy
import os
from collections.abc import Callable
from contextlib import ExitStack
from typing import Any, Literal
from unittest.mock import patch
import torch
import torch._inductor.compile_fx
import torch.fx as fx
import vllm.envs as envs
from vllm.compilation.counter import compilation_counter
from vllm.config import VllmConfig
from vllm.config.utils import Range
from vllm.logger import init_logger
from vllm.utils.hashing import safe_hash
from vllm.utils.torch_utils import is_torch_equal_or_newer
logger = init_logger(__name__)
class CompilerInterface:
    """
    Base interface for compilers that vLLM can drive.
    """

    # The name of the compiler, e.g. inductor.
    # This is a class-level attribute.
    name: str

    def initialize_cache(
        self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
    ) -> None:
        """
        Hook called when the vLLM process uses `cache_dir` as its cache
        directory. A compiler should set itself up against that directory,
        e.g. by re-directing its own cache to a sub-directory of it.

        `prefix` can be combined with `cache_dir` to work out the base cache
        directory: several parts of a model may be compiled while sharing
        one cache directory, e.g.

            cache_dir = "/path/to/dir/backbone", prefix = "backbone"
            cache_dir = "/path/to/dir/eagle_head", prefix = "eagle_head"

        The base implementation does nothing.
        """

    def compute_hash(self, vllm_config: VllmConfig) -> str:
        """
        Compute a hash over the compiler-specific parts of the vLLM config
        so that the compiled model can be cached.

        See [`VllmConfig.compute_hash`][vllm.config.VllmConfig.compute_hash]
        for the information that is already covered by default; this method
        should only account for what is specific to the compiler.

        The base implementation returns an empty string.
        """
        return ""

    def compile(
        self,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        compiler_config: dict[str, Any],
        compile_range: Range,
        key: str | None = None,
    ) -> tuple[Callable[..., Any] | None, Any | None]:
        """
        Compile `graph` for the given example inputs, compiler config and
        range.

        `compile_range` describes the range of the inputs: either a concrete
        size (if compile_sizes is provided), e.g. [4, 4], or a range such as
        [5, 8]. Right now only one variable is supported in ranges for all
        inputs, namely the batchsize (number of tokens) during inference.

        Dynamo will make sure `graph(*example_inputs)` is valid.

        Returns a pair of (compiled callable, handle). The handle can be
        used to load the compiled function directly; it should be a plain
        Python object, preferably a string or a file path for readability.
        A compiler without caching support returns None for the handle, and
        a compiler that fails to compile the graph returns None for the
        compiled function as well.

        `key` is required for InductorStandaloneAdaptor, where it names the
        location of the saved artifact: `cache_dir/key`.

        The base implementation returns `(None, None)`.
        """
        return None, None

    def load(
        self,
        handle: Any,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        graph_index: int,
        compile_range: Range,
    ) -> Callable[..., Any]:
        """
        Load a compiled function from `handle`, which is the second value
        returned by `compile`.

        Implementations raise an error if the handle is invalid. The base
        implementation always raises NotImplementedError.
        """
        raise NotImplementedError("caching is not supported")
class AlwaysHitShapeEnv:
    """
    A dummy shape environment whose guards always pass.

    Why it exists:

    With plain `torch.compile`, each compilation consists of one Dynamo
    bytecode compilation and one Inductor compilation. The Inductor
    compilation runs inside the context of the Dynamo bytecode compilation,
    and that context supplies the dynamic shape information, etc.

    vLLM runs the Dynamo bytecode compilation only once and then runs
    Inductor several times, for different specific shapes plus a general
    shape. The per-shape compilations happen outside the Dynamo context, so
    there is no shape environment to hand to Inductor and its code cache
    lookup would fail.

    Supplying this always-hit environment makes the Inductor code cache
    lookup succeed, so the graph can be compiled for whichever shapes are
    needed.

    The dummy methods below were arrived at by trial-and-error until it
    worked.
    """

    def __init__(self) -> None:
        self.guards: list[Any] = []

    def evaluate_guards_expression(self, *args: Any, **kwargs: Any) -> Literal[True]:
        """Report every guard expression as satisfied."""
        return True

    def get_pruned_guards(self, *args: Any, **kwargs: Any) -> list[Any]:
        """Report that no guards remain."""
        return []

    def produce_guards_expression(self, *args: Any, **kwargs: Any) -> Literal[""]:
        """Produce an empty guards expression."""
        return ""
def get_inductor_factors() -> list[Any]:
factors: list[Any] = []
# summarize system state
from torch._inductor.codecache import CacheBase
system_factors = CacheBase.get_system()
factors.append(system_factors)
# summarize pytorch state
from torch._inductor.codecache import torch_key
torch_factors = torch_key()
factors.append(torch_factors)
return factors
def is_compile_cache_enabled(
    vllm_additional_inductor_config: dict[str, Any],
) -> bool:
    """Tell whether the compile cache may be used.

    The cache counts as enabled only if none of these switches is set:
    `envs.VLLM_DISABLE_COMPILE_CACHE`,
    `torch._inductor.config.force_disable_caches`, and the
    "force_disable_caches" entry of `vllm_additional_inductor_config`.
    """
    disabled_by_vllm_inductor_config = vllm_additional_inductor_config.get(
        "force_disable_caches", False
    )

    # TODO(gmagogsfm): Replace torch._inductor.config.force_disable_caches
    # with torch.compiler.config.force_disable_caches when minimum PyTorch
    # version reaches 2.10
    return not (
        envs.VLLM_DISABLE_COMPILE_CACHE
        or torch._inductor.config.force_disable_caches
        or disabled_by_vllm_inductor_config
    )
class InductorStandaloneAdaptor(CompilerInterface):
    """
    The adaptor for the Inductor compiler, built on
    `torch._inductor.standalone_compile`.
    Requires PyTorch 2.8+.
    This is not on by default yet, but we plan to turn it on by default for
    PyTorch 2.8.

    Use VLLM_USE_STANDALONE_COMPILE to toggle this on or off.
    """

    name = "inductor_standalone"

    def __init__(self, save_format: Literal["binary", "unpacked"]) -> None:
        # Format passed to both `save` and `CompiledArtifact.load`.
        self.save_format = save_format

    def compute_hash(self, vllm_config: VllmConfig) -> str:
        """Return the first 10 hex characters of the hash of the inductor
        factors; `vllm_config` itself is not consulted here."""
        factors = get_inductor_factors()
        hash_str: str = safe_hash(
            str(factors).encode(), usedforsecurity=False
        ).hexdigest()[:10]
        return hash_str

    def initialize_cache(
        self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
    ) -> None:
        """Remember `cache_dir`; `disable_cache` and `prefix` are unused."""
        self.cache_dir = cache_dir

    def compile(
        self,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        compiler_config: dict[str, Any],
        compile_range: Range,
        key: str | None = None,
    ) -> tuple[Callable[..., Any] | None, Any | None]:
        """
        Compile `graph` with `standalone_compile`.

        In AOT mode (PyTorch 2.10.0+ and VLLM_USE_MEGA_AOT_ARTIFACT set) the
        compiled artifact is returned with a None handle. Otherwise the
        artifact is saved under `cache_dir/key` when the compile cache is
        enabled, and `(compiled_graph, (key, path))` is returned.

        Raises:
            RuntimeError: if the compile cache is enabled but the compiled
                artifact is not serializable.
        """
        compilation_counter.num_inductor_compiles += 1
        # Work on a copy so the caller's config dict is not modified here.
        current_config = {}
        if compiler_config is not None:
            current_config.update(compiler_config)
        set_inductor_config(current_config, compile_range)
        set_functorch_config()

        if compile_range.is_single_size():
            dynamic_shapes = "from_example_inputs"
        else:
            dynamic_shapes = "from_graph"

        from torch._inductor import standalone_compile

        supports_aot = is_torch_equal_or_newer("2.10.0")

        if not supports_aot and envs.VLLM_USE_MEGA_AOT_ARTIFACT:
            logger.error(
                "CRITICAL: VLLM_USE_MEGA_AOT_ARTIFACT "
                "is enabled but PyTorch version does not support 'aot' "
                "parameter in standalone_compile. This requires PyTorch "
                "2.10.0+. Falling back to non-AOT mode."
            )

        compile_kwargs = {
            "dynamic_shapes": dynamic_shapes,
            "options": {
                "config_patches": current_config,
            },
        }

        use_aot: bool = supports_aot and envs.VLLM_USE_MEGA_AOT_ARTIFACT
        # only add 'aot' parameter if both supported and enabled...
        # this will set bundled_autograd_cache
        # https://github.com/pytorch/pytorch/blob/9bbc5b2905c260adf41bc866a732f9c121a2828a/torch/_inductor/standalone_compile.py#L359  # noqa
        if use_aot:
            compile_kwargs["aot"] = True  # type: ignore[assignment]

        # Inductor's pre-grad passes don't do anything for vLLM.
        # The pre-grad passes get run even on cache-hit and negatively impact
        # vllm cold compile times by O(1s)
        # Can remove this after the following issue gets fixed
        # https://github.com/pytorch/pytorch/issues/174502
        if envs.VLLM_ENABLE_PREGRAD_PASSES:
            ctx: Any = contextlib.nullcontext()
        else:
            # Replace the pre-grad pass entry point with an identity function.
            ctx = patch(
                "torch._inductor.compile_fx._recursive_pre_grad_passes",
                lambda gm, _: gm,
            )
        with ctx:
            compiled_graph = standalone_compile(graph, example_inputs, **compile_kwargs)

        if use_aot:
            from torch._inductor.standalone_compile import AOTCompiledArtifact

            assert isinstance(compiled_graph, AOTCompiledArtifact)
            assert hasattr(compiled_graph, "serialize")
            # Return the compiled graph with no handle (None): nothing is
            # written to disk on this path.
            return compiled_graph, None

        # Save the compiled artifact to disk in the specified path
        assert key is not None
        path = os.path.join(self.cache_dir, key)

        def is_saveable_2_10(compiled_artifact):
            # can just use compiled_artifact.is_saveable in 2.11
            if compiled_artifact._artifacts is None:
                return False
            _, cache_info = compiled_artifact._artifacts
            return len(cache_info.aot_autograd_artifacts) == 1

        if is_compile_cache_enabled(compiler_config):
            if not is_saveable_2_10(compiled_graph):
                raise RuntimeError(
                    "The compiled artifact is not serializable. This usually means "
                    "that the model code has something that is not serializable "
                    "by torch.compile in it. You can fix this by either "
                    "figuring out what is not serializable and rewriting it, "
                    "filing a bug report, "
                    "or suppressing this error by "
                    "disabling vLLM's compilation cache via "
                    "VLLM_DISABLE_COMPILE_CACHE=1 "
                    "(this will greatly increase vLLM server warm start times)."
                )
            compiled_graph.save(path=path, format=self.save_format)
            compilation_counter.num_compiled_artifacts_saved += 1
        # The (key, path) handle is returned even when nothing was saved.
        return compiled_graph, (key, path)

    def load(
        self,
        handle: Any,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        graph_index: int,
        compile_range: Range,
    ) -> Callable[..., Any]:
        """
        Load the artifact stored at `handle[1]` and return a callable for it.

        `handle` must be a `(str, str)` tuple as returned by `compile`. The
        returned wrapper yields the artifact's output as-is when `graph`
        returns a tuple, and its first element otherwise.
        """
        assert isinstance(handle, tuple)
        assert isinstance(handle[0], str)
        assert isinstance(handle[1], str)
        path = handle[1]
        inductor_compiled_graph = torch._inductor.CompiledArtifact.load(
            path=path, format=self.save_format
        )
        from torch._inductor.compile_fx import graph_returns_tuple

        returns_tuple = graph_returns_tuple(graph)

        def compiled_graph_wrapper(*args: Any) -> tuple[Any, ...] | Any:
            graph_output = inductor_compiled_graph(*args)
            # unpack the tuple if needed
            # TODO(rzou): the implication is that we're not
            # reading the python bytecode correctly in vLLM?
            if returns_tuple:
                return graph_output
            else:
                return graph_output[0]

        return compiled_graph_wrapper
class InductorAdaptor(CompilerInterface):
"""
The adaptor for the Inductor compiler, version 2.5, 2.6, 2.7.
"""
name = "inductor"
def compute_hash(self, vllm_config: VllmConfig) -> str:
factors = get_inductor_factors()
hash_str: str = safe_hash(
str(factors).encode(), usedforsecurity=False
).hexdigest()[:10]
return hash_str
def initialize_cache(
self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
) -> None:
self.cache_dir = cache_dir
self.prefix = prefix
self.base_cache_dir = cache_dir[: -len(prefix)] if prefix else cache_dir
if disable_cache:
return
# redirect the cache directory to a subdirectory
# set flags so that Inductor and Triton store their cache
# in the cache_dir, then users only need to copy the cache_dir
# to another machine to reuse the cache.
inductor_cache = os.path.join(self.base_cache_dir, "inductor_cache")
os.makedirs(inductor_cache, exist_ok=True)
os.environ["TORCHINDUCTOR_CACHE_DIR"] = inductor_cache
triton_cache = os.path.join(self.base_cache_dir, "triton_cache")
os.makedirs(triton_cache, exist_ok=True)
os.environ["TRITON_CACHE_DIR"] = triton_cache
def compile(
self,
graph: fx.GraphModule,
example_inputs: list[Any],
compiler_config: dict[str, Any],
compile_range: Range,
key: str | None = None,
) -> tuple[Callable[..., Any] | None, Any | None]:
compilation_counter.num_inductor_compiles += 1
from torch._inductor.compile_fx import compile_fx
current_config = {}
if compiler_config is not None:
current_config.update(compiler_config)
# disable remote cache
current_config["fx_graph_cache"] = True
current_config["fx_graph_remote_cache"] = False
set_inductor_config(current_config, compile_range)
set_functorch_config()
# inductor can inplace modify the graph, so we need to copy it
# see https://github.com/pytorch/pytorch/issues/138980
graph = copy.deepcopy(graph)
# it's the first time we compile this graph
# the assumption is that we don't have nested Inductor compilation.
# compiled_fx_graph_hash will only be called once, and we can hook
# it to get the hash of the compiled graph directly.
hash_str, file_path = None, None
from torch._inductor.codecache import compiled_fx_graph_hash
def hijacked_compile_fx_inner(*args: Any, **kwargs: Any) -> Any:
output = torch._inductor.compile_fx.compile_fx_inner(*args, **kwargs)
nonlocal hash_str
inductor_compiled_graph = output
if inductor_compiled_graph is not None:
nonlocal file_path
compiled_fn = inductor_compiled_graph.current_callable
file_path = compiled_fn.__code__.co_filename # noqa
if (
not file_path.startswith(self.base_cache_dir)
and compiled_fn.__closure__ is not None
):
# hooked in the align_inputs_from_check_idxs function
# in torch/_inductor/utils.py
for cell in compiled_fn.__closure__:
if not callable(cell.cell_contents):
continue
code = cell.cell_contents.__code__
if code.co_filename.startswith(self.base_cache_dir):
# this is the real file path
# compiled from Inductor
file_path = code.co_filename
break
hash_str = inductor_compiled_graph._fx_graph_cache_key
return output
def hijack_compiled_fx_graph_hash(*args: Any, **kwargs: Any) -> Any:
out = compiled_fx_graph_hash(*args, **kwargs)
nonlocal hash_str
hash_str = out[0]
return out
def _check_can_cache(*args: Any, **kwargs: Any) -> None:
# no error means it can be cached.
# Inductor refuses to cache the graph outside of Dynamo
# tracing context, and also disables caching for graphs
# with high-order ops.
# For vLLM, in either case, we want to cache the graph.
# see https://github.com/pytorch/pytorch/blob/9f5ebf3fc609105a74eab4ccc24932d6353ff566/torch/_inductor/codecache.py#L1221 # noqa
return
def _get_shape_env() -> AlwaysHitShapeEnv:
return AlwaysHitShapeEnv()
with ExitStack() as stack:
# for hijacking the hash of the compiled graph
stack.enter_context(
patch(
"torch._inductor.codecache.compiled_fx_graph_hash",
hijack_compiled_fx_graph_hash,
)
)
# for providing a dummy shape environment
stack.enter_context(
patch(
"torch._inductor.codecache.FxGraphCache._get_shape_env",
_get_shape_env,
)
)
from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache
# torch 2.8+ on main uses _get_shape_env in AOTAutogradCache
if hasattr(AOTAutogradCache, "_get_shape_env"):
stack.enter_context(
patch(
"torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env",
_get_shape_env,
)
)
# for forcing the graph to be cached
stack.enter_context(
patch(
"torch._inductor.codecache.FxGraphCache._check_can_cache",
_check_can_cache,
)
)
# Dynamo metrics context, see method for more details.
stack.enter_context(self.metrics_context())
# Disable remote caching. When these are on, on remote cache-hit,
# the monkey-patched functions never actually get called.
# vLLM today assumes and requires the monkey-patched functions to
# get hit.
# TODO(zou3519): we're going to replace this all with
# standalone_compile sometime.
stack.enter_context(
torch._inductor.config.patch(fx_graph_remote_cache=False)
)
# InductorAdaptor (unfortunately) requires AOTAutogradCache
# to be turned off to run. It will fail to acquire the hash_str
# and error if not.
# StandaloneInductorAdaptor (PyTorch 2.8+) fixes this problem.
stack.enter_context(
torch._functorch.config.patch(enable_autograd_cache=False)
)
stack.enter_context(
torch._functorch.config.patch(enable_remote_autograd_cache=False)
)
compiled_graph = compile_fx(
graph,
example_inputs,
inner_compile=hijacked_compile_fx_inner,
config_patches=current_config,
)
# Turn off the checks if we disable the compilation cache.
if is_compile_cache_enabled(compiler_config):
if hash_str is None:
raise RuntimeError(
"vLLM failed to compile the model. The most "
"likely reason for this is that a previous compilation "
"failed, leading to a corrupted compilation artifact. "
"We recommend trying to "
"remove ~/.cache/vllm/torch_compile_cache and try again "
"to see the real issue. "
)
assert file_path is not None, (
"failed to get the file path of the compiled graph"
)
return compiled_graph, (hash_str, file_path)
def load(
    self,
    handle: Any,
    graph: fx.GraphModule,
    example_inputs: list[Any],
    graph_index: int,
    compile_range: Range,
) -> Callable[..., Any]:
    """Fetch an already-compiled graph from Inductor's FX graph cache.

    Args:
        handle: A tuple whose first two items are strings. Only the first
            item (the cache key, i.e. the ``hash_str`` that ``compile``
            returns in its ``(hash_str, file_path)`` handle) is used here.
        graph: The FX graph module; used to build the constants wrapper
            for the lookup and to decide whether the graph returns a tuple.
        example_inputs: Example inputs forwarded to the cache lookup.
        graph_index: Not used by this implementation.
        compile_range: Not used by this implementation.

    Returns:
        A callable following the Dynamo calling convention
        (``f(*args) -> Any``) that forwards to the Inductor artifact.

    Raises:
        AssertionError: If ``handle`` does not have the expected shape or
            the cache lookup yields nothing.
    """
    assert isinstance(handle, tuple)
    assert isinstance(handle[0], str)
    assert isinstance(handle[1], str)
    cache_key = handle[0]

    from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache
    from torch._inductor.codecache import FxGraphCache

    def _always_hit_shape_env(*_args: Any, **_kwargs: Any) -> Any:
        # Stand-in for `_get_shape_env` while the lookup runs.
        return AlwaysHitShapeEnv()

    shape_env_targets = ["torch._inductor.codecache.FxGraphCache._get_shape_env"]
    # torch 2.8+ on main uses _get_shape_env in AOTAutogradCache
    if hasattr(AOTAutogradCache, "_get_shape_env"):
        shape_env_targets.append(
            "torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env"
        )

    with ExitStack() as exit_stack:
        for target in shape_env_targets:
            exit_stack.enter_context(patch(target, _always_hit_shape_env))

        # Dynamo metrics context, see method for more details.
        exit_stack.enter_context(self.metrics_context())

        from torch._inductor.output_code import CompiledFxGraphConstantsWithGm

        constants = CompiledFxGraphConstantsWithGm(graph)
        # NOTE(review): the positional `True, None` are presumably
        # "local cache only" and "no remote cache" - confirm against the
        # installed torch version's `_lookup_graph` signature.
        lookup_result = FxGraphCache._lookup_graph(
            cache_key, example_inputs, True, None, constants
        )
        inductor_compiled_graph = lookup_result[0]
        assert inductor_compiled_graph is not None, (
            "Inductor cache lookup failed. Please remove "
            "the cache directory and try again."
        )

    # Inductor calling convention (function signature):
    #   f(list) -> tuple
    # Dynamo calling convention (function signature):
    #   f(*args) -> Any
    # so we need to know whether the graph itself returns a tuple.
    from torch._inductor.compile_fx import graph_returns_tuple

    returns_tuple = graph_returns_tuple(graph)

    # this is the callable we return to Dynamo to run
    def compiled_graph(*args: Any) -> tuple[Any, ...] | Any:
        outputs = inductor_compiled_graph(list(args))
        # unwrap the single result when the graph does not return a tuple
        return outputs if returns_tuple else outputs[0]

    return compiled_graph
def metrics_context(self) -> contextlib.AbstractContextManager[Any]:
    """
    Return the Dynamo metrics context, or a null context when it does
    not exist.

    The metrics context is present in torch>=2.6. FxGraphCache uses it in
    torch==2.6 (but not after), and other torch.compile internals might
    use it as well.

    Because it is re-entrant, we always set it (even if entering via Dynamo
    and the context was already entered). We might want to revisit whether
    it should be set at a different mode of compilation.

    This is likely a bug in PyTorch: public APIs should not rely on
    manually setting up internal contexts. But we also rely on non-public
    APIs which might not provide these guarantees.
    """
    if not is_torch_equal_or_newer("2.6"):
        return contextlib.nullcontext()

    import torch._dynamo.utils

    return torch._dynamo.utils.get_metrics_context()  # type: ignore[no-any-return]
def set_inductor_config(config: dict[str, Any], compile_range: Range) -> None:
    """Add autotuning options to ``config`` for single-size compile ranges.

    ``config`` is modified in place. Ranges that are not a single size
    leave it untouched.
    """
    if not compile_range.is_single_size():
        return

    # for a specific batch size, tuning triton kernel parameters
    # can be beneficial
    max_autotune = envs.VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE
    config["max_autotune"] = max_autotune
    coordinate_descent = envs.VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING
    config["coordinate_descent_tuning"] = coordinate_descent
def set_functorch_config() -> None:
    """Switch off functorch's ``bundled_autograd_cache`` setting.

    Nothing is changed when ``VLLM_USE_MEGA_AOT_ARTIFACT`` is enabled.
    The assignment is made on the global ``torch._functorch.config``.
    """
    if envs.VLLM_USE_MEGA_AOT_ARTIFACT:
        return
    torch._functorch.config.bundled_autograd_cache = False
class EagerAdaptor(CompilerInterface):
    """Adaptor that skips compilation and hands back the FX graph itself."""

    name = "eager"

    def compile(
        self,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        compiler_config: dict[str, Any],
        compile_range: Range,
        key: str | None = None,
    ) -> tuple[Callable[..., Any] | None, Any | None]:
        """Count the call and return ``(graph, None)``.

        The graph module is returned unchanged as the callable. The second
        item (the cache handle) is ``None`` because this adaptor does not
        support caching. All other arguments are ignored.
        """
        compilation_counter.num_eager_compiles += 1
        no_cache_handle = None
        return graph, no_cache_handle

View File

@@ -0,0 +1,50 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
import dataclasses
from collections.abc import Generator
from contextlib import contextmanager
from typing import Any
@dataclasses.dataclass
class CompilationCounter:
    """Plain integer counters describing compilation activity.

    Every field starts at zero and is incremented by the component that
    owns the corresponding event. `expect` lets a caller assert by how
    much selected counters change across a block of code.
    """

    num_models_seen: int = 0
    num_graphs_seen: int = 0
    # including the splitting ops
    num_piecewise_graphs_seen: int = 0
    # not including the splitting ops
    num_piecewise_capturable_graphs_seen: int = 0
    num_backend_compilations: int = 0
    # Number of gpu_model_runner attempts to trigger CUDAGraphs capture
    num_gpu_runner_capture_triggers: int = 0
    # Number of CUDAGraphs captured
    num_cudagraph_captured: int = 0
    # InductorAdaptor.compile calls
    num_inductor_compiles: int = 0
    # EagerAdaptor.compile calls
    num_eager_compiles: int = 0
    # The number of times vLLM's compiler cache entry was updated
    num_cache_entries_updated: int = 0
    # The number of standalone_compile compiled artifacts saved
    num_compiled_artifacts_saved: int = 0
    # Number of times a model was loaded with CompilationMode.STOCK_TORCH_COMPILE
    stock_torch_compile_count: int = 0

    def clone(self) -> "CompilationCounter":
        """Return an independent deep copy of the current counter values."""
        return copy.deepcopy(self)

    @contextmanager
    def expect(self, **kwargs: Any) -> Generator[None, None, None]:
        """Assert that counters change by exactly the given amounts.

        Each keyword names a counter field and gives the difference
        expected between entering and leaving the ``with`` block. The
        check runs only when the block finishes without raising.

        Raises:
            AssertionError: If a named counter changed by a different amount.
        """
        snapshot = self.clone()
        yield
        for field_name, expected_delta in kwargs.items():
            before = getattr(snapshot, field_name)
            after = getattr(self, field_name)
            assert after - before == expected_delta, (
                f"{field_name} not as expected, before it is {before}"
                f", after it is {after}, "
                f"expected diff is {expected_delta}"
            )


# Module-level counter instance, created once at import time.
compilation_counter = CompilationCounter()

View File

@@ -0,0 +1,332 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
from collections import Counter
from collections.abc import Callable
from contextlib import ExitStack
from typing import Any
from unittest.mock import patch
import torch
import vllm.envs as envs
from vllm.compilation.counter import compilation_counter
from vllm.compilation.monitor import validate_cudagraph_capturing_enabled
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.distributed.device_communicators.pynccl_allocator import set_graph_pool_id
from vllm.forward_context import BatchDescriptor, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.offloader.base import get_offloader
from vllm.platforms import current_platform
from vllm.utils.torch_utils import current_stream, weak_ref_tensors
from vllm.sequence import IntermediateTensors
# Module-level logger: receives the cudagraph capture debug message and is the
# default sink (`logger.info`) of `CUDAGraphLogging.log`.
logger = init_logger(__name__)
@dataclasses.dataclass(frozen=True)
class CUDAGraphStat:
    """One cudagraph observation, rendered as a row of the metrics table.

    The dataclass is frozen, so instances are hashable and equal
    observations can be tallied with ``collections.Counter``.
    """

    # Shown in the "Unpadded Tokens" column.
    num_unpadded_tokens: int
    # Shown in the "Padded Tokens" column.
    num_padded_tokens: int
    # Shown in the "Num Paddings" column.
    num_paddings: int
    # Shown in the "Runtime Mode" column; already a string.
    runtime_mode: str
class CUDAGraphLogging:
    """Aggregate and log cudagraph metrics.

    Observations are collected with `observe`, rendered as a Markdown-style
    table by `generate_metric_table`, and emitted (then cleared) by `log`.
    """

    COLUMN_HEADERS = [
        "Unpadded Tokens",
        "Padded Tokens",
        "Num Paddings",
        "Runtime Mode",
        "Count",
    ]

    def __init__(
        self, cg_mode: CUDAGraphMode, cg_capture_sizes: list[int] | None
    ) -> None:
        """Store the stringified settings and build the fixed report preamble.

        Args:
            cg_mode: CUDA graph mode; kept as ``str(cg_mode)``.
            cg_capture_sizes: Capture sizes; ``None`` (or an empty list) is
                reported as ``[]``.
        """
        self.reset()
        self.cg_mode = str(cg_mode)
        self.cg_capture_sizes = str(cg_capture_sizes or [])
        self.settings_header = (
            "**CUDAGraph Config Settings:**\n\n"
            f"- Mode: {self.cg_mode}\n"
            f"- Capture sizes: {self.cg_capture_sizes}\n\n"
            "**CUDAGraph Stats:**\n\n"
        )

    def reset(self) -> None:
        """Discard all collected observations."""
        self.stats: list[CUDAGraphStat] = []

    def observe(self, cudagraph_stat: CUDAGraphStat) -> None:
        """Record one observation."""
        self.stats.append(cudagraph_stat)

    def generate_metric_table(self) -> str:
        """Render the settings preamble followed by a table of observations.

        Identical observations are collapsed into one row with a count,
        and rows are listed from most to least frequent.
        """
        # most_common() orders by descending count and keeps first-seen
        # order among ties.
        rows = [
            [
                str(stat.num_unpadded_tokens),
                str(stat.num_padded_tokens),
                str(stat.num_paddings),
                stat.runtime_mode,
                str(count),
            ]
            for stat, count in Counter(self.stats).most_common()
        ]

        # Each column is as wide as its widest cell, header included.
        col_widths = [
            max([len(header_text), *(len(row[i]) for row in rows)])
            for i, header_text in enumerate(self.COLUMN_HEADERS)
        ]

        padded_headers = (
            h.ljust(w) for h, w in zip(self.COLUMN_HEADERS, col_widths)
        )
        table_header = "| " + " | ".join(padded_headers) + " |\n"
        table_separator = "|" + "|".join("-" * (w + 2) for w in col_widths) + "|\n"

        # Left-align every cell to its column width.
        table_body = "\n".join(
            "| "
            + " | ".join(str(val).ljust(width) for val, width in zip(row, col_widths))
            + " |"
            for row in rows
        )

        return self.settings_header + table_header + table_separator + table_body + "\n"

    def log(self, log_fn: Callable[..., Any] = logger.info) -> None:
        """Emit the table through ``log_fn`` and clear the observations.

        Does nothing when no observation has been recorded.
        """
        if self.stats:
            log_fn(self.generate_metric_table())
            self.reset()
@dataclasses.dataclass
class CUDAGraphEntry:
    """Per-batch-descriptor capture state kept by `CUDAGraphWrapper`."""

    # Key under which this entry is cached.
    batch_descriptor: BatchDescriptor
    # The captured graph; None until a capture has happened for this key.
    cudagraph: torch.cuda.CUDAGraph | None = None
    # Output recorded at capture time and returned again on every replay.
    output: Any | None = None

    # for cudagraph debugging, track the input addresses
    # during capture, and check if they are the same during replay
    input_addresses: list[int] | None = None
@dataclasses.dataclass
class CUDAGraphOptions:
    """Switches that adjust how `CUDAGraphWrapper` performs a capture."""

    # Emit a debug log line when a graph is captured.
    debug_log_enable: bool = True
    # Replace `gc.collect` and `torch.cuda.empty_cache` with no-ops while
    # capturing.
    gc_disable: bool = False
    # Convert the captured output to weak references inside the capture.
    weak_ref_output: bool = True
class CUDAGraphWrapper:
    """Wraps a runnable to add CUDA graph capturing and replaying ability, and
    provides attribute access to the underlying `runnable` via `__getattr__`.

    The workflow of this wrapper in the cudagraph dispatching is as follows:
    1. At initialization, a runtime mode is assigned to the wrapper (FULL or
    PIECEWISE).
    2. At runtime, the wrapper receives a runtime_mode and a
    batch_descriptor (key) from the forward context and blindly trusts them
    for cudagraph dispatching.
    3. If runtime_mode is NONE or runtime_mode does not match the mode of the
    wrapper, just call the runnable directly.
    4. Otherwise, i.e., the runtime_mode matches the mode of the wrapper,
    the wrapper will perform cudagraph capture (if the key does not exist,
    create a new entry and cache it) or replay (if the key exists in the
    cache).

    Note: CUDAGraphWrapper does not store persistent buffers or copy any
    runtime inputs into such buffers for replay. We assume implementing them
    is done outside of the wrapper. That is because we do not make any
    assumption on the dynamic shape (batch size) of the runtime inputs, as a
    trade-off for staying orthogonal to compilation logic. Nevertheless,
    tracing and checking the input addresses to be consistent during replay is
    guaranteed when VLLM_LOGGING_LEVEL == "DEBUG".
    """

    def __init__(
        self,
        runnable: Callable[..., Any],
        vllm_config: VllmConfig,
        runtime_mode: CUDAGraphMode,
        cudagraph_options: CUDAGraphOptions | None = None,
    ) -> None:
        """Set up the wrapper.

        Args:
            runnable: The callable to capture and replay.
            vllm_config: Global config; its ``compilation_config`` is cached.
            runtime_mode: Mode this wrapper reacts to; must not be ``NONE``.
            cudagraph_options: Capture options; defaults to
                ``CUDAGraphOptions()`` when omitted.
        """
        self.runnable = runnable
        self.vllm_config = vllm_config
        self.runtime_mode = runtime_mode
        self.compilation_config = vllm_config.compilation_config

        # NOTE(review): not read anywhere inside this class.
        self.first_run_finished = False
        self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"

        # assert runtime_mode is not NONE(no cudagraph), otherwise, we don't
        # need to initialize a CUDAGraphWrapper.
        assert self.runtime_mode != CUDAGraphMode.NONE
        # TODO: in the future, if we want to use multiple
        # streams, it might not be safe to share a global pool.
        # only investigate this when we use multiple streams
        self.graph_pool = current_platform.get_global_graph_pool()

        if cudagraph_options is None:
            cudagraph_options = CUDAGraphOptions()
        self.cudagraph_options = cudagraph_options
        # the entries for different batch descriptors that we need to capture
        # cudagraphs for.
        self.concrete_cudagraph_entries: dict[BatchDescriptor, CUDAGraphEntry] = {}

    def __getattr__(self, key: str) -> Any:
        """Forward unknown attribute lookups to the wrapped runnable.

        Raises:
            AttributeError: If the runnable lacks ``key``, or if ``runnable``
                itself has not been set on this instance yet.
        """
        # `__getattr__` is only reached when normal lookup failed. If the
        # missing attribute is `runnable` itself (instance created by
        # copy/pickle/`__new__` before `__init__` ran), touching
        # `self.runnable` below would re-enter this method forever.
        if key == "runnable":
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute 'runnable'"
            )
        # allow accessing the attributes of the runnable.
        if hasattr(self.runnable, key):
            return getattr(self.runnable, key)
        raise AttributeError(
            f"Attribute {key} not exists in the runnable of "
            f"cudagraph wrapper: {self.runnable}"
        )

    def unwrap(self) -> Callable[..., Any]:
        """Return the original, unwrapped runnable."""
        return self.runnable

    def weak_ref_tensors_with_intermediate(self, output: Any) -> Any:
        """Convert ``output`` to weak tensor references.

        An `IntermediateTensors` is rebuilt with each of its tensors
        converted individually; anything else is passed to
        `weak_ref_tensors` as a whole.
        """
        if isinstance(output, IntermediateTensors):
            return IntermediateTensors(
                tensors={
                    key: weak_ref_tensors(value)
                    for key, value in output.tensors.items()
                }
            )
        return weak_ref_tensors(output)

    def __call__(self, *args: Any, **kwargs: Any) -> Any | None:
        """Run, capture, or replay depending on the forward context.

        The runnable is called directly when the context's runtime mode is
        ``NONE`` or differs from this wrapper's mode. Otherwise the first
        call for a batch descriptor captures a CUDA graph and later calls
        replay it, returning the output recorded at capture time.
        """
        forward_context = get_forward_context()
        batch_descriptor = forward_context.batch_descriptor
        cudagraph_runtime_mode = forward_context.cudagraph_runtime_mode

        if (
            cudagraph_runtime_mode == CUDAGraphMode.NONE
            or cudagraph_runtime_mode != self.runtime_mode
        ):
            # CUDAGraphMode.NONE could mean the profile run, a warmup run, or
            # running without cudagraphs.
            # We do not trigger capture/replay if the runtime mode does not
            # match. This enables properly dispatching to the correct
            # CUDAGraphWrapper when nesting multiple instances with different
            # runtime modes.
            return self.runnable(*args, **kwargs)

        assert batch_descriptor is not None
        if batch_descriptor not in self.concrete_cudagraph_entries:
            # create a new entry for this batch descriptor
            self.concrete_cudagraph_entries[batch_descriptor] = CUDAGraphEntry(
                batch_descriptor=batch_descriptor
            )

        entry = self.concrete_cudagraph_entries[batch_descriptor]

        if entry.cudagraph is None:
            if self.cudagraph_options.debug_log_enable:
                # Since we capture cudagraph for many different shapes and
                # capturing is fast, we don't need to log it for every
                # shape. E.g. we only log it for the first subgraph in
                # piecewise mode.
                logger.debug(
                    "Capturing a cudagraph on (%s,%s)",
                    self.runtime_mode.name,
                    entry.batch_descriptor,
                )
            # validate that cudagraph capturing is legal at this point.
            validate_cudagraph_capturing_enabled()

            # Remember the addresses of the positional tensor inputs so
            # that replay can verify them in debugging mode.
            input_addresses = [
                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
            ]
            entry.input_addresses = input_addresses
            cudagraph = torch.cuda.CUDAGraph()

            with ExitStack() as stack:
                if self.cudagraph_options.gc_disable:
                    # during every model forward for piecewise cudagraph
                    # mode, we will capture many pieces of cudagraphs
                    # (roughly one per layer). running gc again and again
                    # across layers will make the cudagraph capture very slow.
                    # therefore, we only run gc for the first graph,
                    # and disable gc for the rest of the graphs.
                    stack.enter_context(patch("gc.collect", lambda: None))
                    stack.enter_context(patch("torch.cuda.empty_cache", lambda: None))

                if self.graph_pool is not None:
                    set_graph_pool_id(self.graph_pool)
                else:
                    set_graph_pool_id(current_platform.graph_pool_handle())

                # Sync offloader's copy stream before capture.
                # Ensure any pre-capture prefetches from offloader are complete.
                get_offloader().sync_prev_onload()

                # mind-exploding: carefully manage the reference and memory.
                with torch.cuda.graph(
                    cudagraph,
                    pool=self.graph_pool,
                    stream=current_stream(),
                ):
                    # `output` is managed by pytorch's cudagraph pool
                    output = self.runnable(*args, **kwargs)
                    # Join offloader's copy stream after forward to avoid
                    # unjoined stream error. The last layer's start_prefetch
                    # forks copy_stream, but wait_prefetch only happens in
                    # the next forward pass.
                    get_offloader().join_after_forward()
                    if self.cudagraph_options.weak_ref_output:
                        # by converting it to weak ref,
                        # the original `output` will immediately be released
                        # to save memory. It is only safe to do this for
                        # the last graph in piecewise cudagraph mode, because
                        # the output of the last graph will not be used by
                        # any other cuda graph.
                        output = self.weak_ref_tensors_with_intermediate(output)

            # here we always use weak ref for the output
            # to save memory
            entry.output = self.weak_ref_tensors_with_intermediate(output)
            entry.cudagraph = cudagraph

            compilation_counter.num_cudagraph_captured += 1

            # important: we need to return the output, rather than
            # the weak ref of the output, so that pytorch can correctly
            # manage the memory during cuda graph capture
            return output

        if self.is_debugging_mode:
            # check if the input addresses are the same
            new_input_addresses = [
                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
            ]
            assert new_input_addresses == entry.input_addresses, (
                f"Input addresses for cudagraphs are different "
                f"during replay. Expected {entry.input_addresses}, "
                f"got {new_input_addresses}"
            )

        # Sync offloader before replay - ensures any external dependencies
        # from pre-capture prefetches are satisfied.
        get_offloader().sync_prev_onload()
        entry.cudagraph.replay()
        return entry.output

View File

@@ -0,0 +1,657 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import hashlib
import inspect
import os
import sys
from collections.abc import Callable, Generator
from typing import TYPE_CHECKING, Any, TypeVar, overload
from unittest.mock import patch
import torch
import torch.nn as nn
from torch._dynamo.symbolic_convert import InliningInstructionTranslator
import vllm.envs as envs
from vllm.compilation.counter import compilation_counter
from vllm.compilation.wrapper import TorchCompileWithNoGuardsWrapper
from vllm.config import (
CompilationMode,
VllmConfig,
get_current_vllm_config,
set_current_vllm_config,
)
from vllm.config.compilation import DynamicShapesType
from vllm.forward_context import get_forward_context, is_forward_context_available
from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors
from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.utils.torch_utils import is_torch_equal_or_newer
from .monitor import start_monitoring_torch_compile
if TYPE_CHECKING:
# Only added on nightly/2.10 so wrap
try:
from torch._dynamo.package import SourceInfo
except ImportError:
# Fallback for old versions not supporting
SourceInfo = Any
# Module-level logger for the compile decorators in this module.
logger = init_logger(__name__)
# Name of the class attribute that carries the opt-out flag: written by
# `ignore_torch_compile` and read by `_should_ignore_torch_compile`.
IGNORE_COMPILE_KEY = "_ignore_compile_vllm"
# Type variable for the classes the decorators accept and return; bound to
# `nn.Module`.
_T = TypeVar("_T", bound=nn.Module)
def ignore_torch_compile(cls: type[_T]) -> type[_T]:
    """
    Class decorator that opts `cls` out of `support_torch_compile`.

    Use it when a parent class is decorated with `support_torch_compile`
    but the inheriting class `cls` should not have its forward compiled.
    Only the forward of the class the decorator is applied to is affected:

    - If the parent has `ignore_torch_compile` but the child has
      `support_torch_compile`, the child will still be compiled.
    - If the class has one or more submodules decorated with
      `support_torch_compile`, those submodules are still compiled.

    The class is modified in place (a flag attribute is set) and returned.
    """
    opt_out = True
    setattr(cls, IGNORE_COMPILE_KEY, opt_out)
    return cls
def _should_ignore_torch_compile(cls: type[_T]) -> bool:
    """
    Return the opt-out flag stored on `cls`, or False when the flag
    attribute is absent.
    """
    not_flagged = False
    return getattr(cls, IGNORE_COMPILE_KEY, not_flagged)
@overload
def support_torch_compile(
    *,
    enable_if: Callable[[VllmConfig], bool] | None = None,
) -> Callable[[type[_T]], type[_T]]: ...


@overload
def support_torch_compile(
    *,
    dynamic_arg_dims: dict[str, int | list[int]] | None,
) -> Callable[[type[_T]], type[_T]]: ...


@overload
def support_torch_compile(
    *,
    mark_unbacked_dims: dict[str, int | list[int]] | None,
) -> Callable[[type[_T]], type[_T]]: ...


@overload
def support_torch_compile(
    *,
    dynamic_arg_dims: dict[str, int | list[int]] | None,
    mark_unbacked_dims: dict[str, int | list[int]] | None,
) -> Callable[[type[_T]], type[_T]]: ...


@overload
def support_torch_compile(cls: type[_T]) -> type[_T]: ...


def support_torch_compile(
    cls: type[_T] | None = None,
    *,
    dynamic_arg_dims: dict[str, int | list[int]] | None = None,
    mark_unbacked_dims: dict[str, int | list[int]] | None = None,
    enable_if: Callable[[VllmConfig], bool] | None = None,
    shape_invariants: Callable[..., None] = lambda *args, **kwargs: None,
) -> Callable[[type[_T]], type[_T]] | type[_T]:
    """
    A decorator to add support for compiling the forward method of a class.

    Usage 1: use directly as a decorator without arguments:

    ```python
    @support_torch_compile
    class MyModel(nn.Module):
        def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): ...
    ```

    Usage 2: use as a decorator with arguments:

    ```python
    @support_torch_compile(dynamic_arg_dims={"x": 0, "y": 0})
    class MyModel(nn.Module):
        def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): ...
    ```

    `dynamic_arg_dims` is a dictionary that maps argument names to the dynamic
    dimensions of the argument. The dynamic dimensions can be either a single
    integer or a list of integers.

    If `dynamic_arg_dims` is `None`, it is inferred from the type annotations
    of the `forward` method, based on the following default rules:

    - if the argument is annotated as `torch.Tensor` or
        `Optional[torch.Tensor]`, the first dimension will be
        marked as dynamic.
    - if the argument is annotated as `IntermediateTensors`, the first
        dimension of all the tensors in the intermediate tensors
        will be marked as dynamic.

    During runtime, when we actually mark dimensions of tensors,
    it depends on the value of arguments:

    - if it is a single integer (can be negative), the corresponding dimension
        of the argument will be marked as dynamic.
    - if it is `None`, ignored.
    - if it is `IntermediateTensors`, all the tensors in the intermediate
        tensors will be marked as dynamic.
    - otherwise, it will raise an error.

    NOTE: if an argument is `None`, it should always be passed as `None` during
    the lifetime of the model, otherwise, it cannot be captured as a single
    computation graph.

    `enable_if` is a function that takes a `VllmConfig` object as input and
    returns a boolean value indicating whether to compile the model or not.
    This is useful if you want to compile the model only when certain
    conditions are met.

    `mark_unbacked_dims` is a dictionary that maps argument names with a dynamic
    dim to be decorated with `mark_unbacked`. This is useful if we would like to
    enforce that dynamo does not specialize on 0/1 values in the case of dummy
    input such as for vision model compilation.

    `shape_invariants` is a function that gets compiled right before forward.
    The function should have the torch._check calls that are needed to set
    the relationships between different input sizes. For example:
    torch._check(input_ids.size()[0] == inputs_embeds.size()[0])
    This enforces constraints on the symbolic shapes without hardcoding
    specific values. It is needed for some models to avoid data dependent
    errors.

    Raises (when the decorator is applied to a class):
        TypeError: the class has no `forward` attribute.
        ValueError: no dynamic dimensions could be determined, or a key of
            the dynamic-dims mapping is not a parameter of `forward`.
    """

    def cls_decorator_helper(cls: type[_T]) -> type[_T]:
        # Resolves the dynamic dims here and hands everything to
        # `_support_torch_compile`, which keeps that function less nested.
        if not hasattr(cls, "forward"):
            raise TypeError("decorated class should have a forward method.")

        forward_params = inspect.signature(cls.forward).parameters

        if dynamic_arg_dims is not None:
            resolved_dims = dynamic_arg_dims
        else:
            # Annotations whose arguments get dimension 0 marked as dynamic.
            dynamic_annotations = [
                torch.Tensor,
                torch.Tensor | None,
                IntermediateTensors,
                IntermediateTensors | None,
            ]
            resolved_dims = {
                param_name: 0
                for param_name, param in forward_params.items()
                if param.annotation in dynamic_annotations
            }
            logger.debug(
                ("Inferred dynamic dimensions for forward method of %s: %s"),
                cls,
                list(resolved_dims),
            )

        if len(resolved_dims) == 0:
            raise ValueError(
                "No dynamic dimensions found in the forward method of "
                f"{cls}. Please provide dynamic_arg_dims explicitly."
            )

        # Every named argument must be a parameter of forward().
        for param_name in resolved_dims:
            if param_name not in forward_params:
                raise ValueError(
                    f"Argument {param_name} not found in the forward method of {cls}"
                )

        return _support_torch_compile(
            cls,
            resolved_dims,
            mark_unbacked_dims,
            enable_if,
            shape_invariants,
        )

    if cls is None:
        # called with keyword arguments: return the real decorator
        return cls_decorator_helper

    # use `support_torch_compile` as a decorator without arguments
    assert isinstance(cls, type)
    return cls_decorator_helper(cls)
def _model_hash_key(fn: Callable[..., Any]) -> str:
    """Return a SHA-256 hex digest that identifies ``fn`` for caching.

    The digest covers, in this order: the vLLM version string, the
    function's qualified name, and the first line number of its code.
    """
    import vllm

    digest = hashlib.sha256()
    hash_parts = (
        vllm.__version__,
        fn.__qualname__,
        str(fn.__code__.co_firstlineno),
    )
    for part in hash_parts:
        digest.update(part.encode())
    return digest.hexdigest()
def _verify_source_unchanged(
    source_info: "SourceInfo", vllm_config: VllmConfig
) -> None:
    """Check that the sources recorded in ``source_info`` still match.

    For every inlined source, the file of its (already imported) module is
    added to ``vllm_config.compilation_config.traced_files`` and paired
    with the recorded content. A hash over the recorded contents is then
    compared with a hash computed from the file paths alone.

    Raises:
        RuntimeError: If the two hashes differ.
        KeyError: If a recorded module is not present in ``sys.modules``.
    """
    from .caching import _compute_code_hash, _compute_code_hash_with_content

    recorded_contents = {}
    for inlined in source_info.inlined_sources:
        source_path = inspect.getfile(sys.modules[inlined.module])
        vllm_config.compilation_config.traced_files.add(source_path)
        recorded_contents[source_path] = inlined.content

    # Hash of the contents captured when the artifact was produced.
    expected_checksum = _compute_code_hash_with_content(recorded_contents)
    # NOTE(review): presumably hashes the files as they are on disk now -
    # confirm in `caching._compute_code_hash`.
    actual_checksum = _compute_code_hash(set(recorded_contents))

    if expected_checksum == actual_checksum:
        return
    raise RuntimeError(
        "Source code has changed since the last compilation. Recompiling the model."
    )
def _support_torch_compile(
cls: type[_T],
dynamic_arg_dims: dict[str, int | list[int]],
mark_unbacked_dims: dict[str, int | list[int]] | None = None,
enable_if: Callable[[VllmConfig], bool] | None = None,
shape_invariants: Callable[..., None] = lambda *args, **kwargs: None,
) -> type[_T]:
"""
A decorator to add support for compiling the forward method of a class.
"""
if TorchCompileWithNoGuardsWrapper in cls.__bases__:
# support decorating multiple times
return cls
# take care of method resolution order
# make sure super().__init__ is called on the base class
# other than TorchCompileWithNoGuardsWrapper
cls.__bases__ = cls.__bases__ + (TorchCompileWithNoGuardsWrapper,)
old_init = cls.__init__
setattr(cls, IGNORE_COMPILE_KEY, False)
def __init__(
self: _T,
*,
vllm_config: VllmConfig | None = None,
prefix: str = "",
**kwargs: Any,
) -> None:
if vllm_config is None:
vllm_config = get_current_vllm_config()
# NOTE: to support multimodal models (such as encoder),
# we may not have vllm_config so we may need to patch
# it
sig = inspect.signature(old_init)
if "vllm_config" in sig.parameters:
kwargs["vllm_config"] = vllm_config
if "prefix" in sig.parameters:
kwargs["prefix"] = prefix
old_init(self, **kwargs)
self.vllm_config = vllm_config
self.compilation_config = self.vllm_config.compilation_config
enable_compile = enable_if is None or enable_if(vllm_config)
# for CompilationMode.STOCK_TORCH_COMPILE , the upper level model runner
# will handle the compilation, so we don't need to do anything here.
self.do_not_compile = (
self.compilation_config.mode
in [CompilationMode.NONE, CompilationMode.STOCK_TORCH_COMPILE]
or _should_ignore_torch_compile(self.__class__)
or not enable_compile
)
if self.do_not_compile:
return
self._check_shape_invariants = shape_invariants
self.was_aot_compile_fn_loaded_from_disk = False
compilation_counter.num_models_seen += 1
self.compiled = False
# Handled by monkeypatching `TorchCompileWithNoGuardsWrapper` into base class
TorchCompileWithNoGuardsWrapper.__init__(self)
cls.__init__ = __init__
def _mark_dynamic_inputs(
mod: type[_T], ds_type: DynamicShapesType, *args: Any, **kwargs: Any
) -> None:
def mark_dynamic(arg: torch.Tensor, dims: list[int]) -> None:
if ds_type == DynamicShapesType.UNBACKED:
if is_torch_equal_or_newer("2.10.0"):
for dim in dims:
torch._dynamo.decorators.mark_unbacked(
arg, dim, hint_override=arg.size()[dim]
)
else:
torch._dynamo.decorators.mark_unbacked(arg, dims)
else:
torch._dynamo.mark_dynamic(arg, dims)
sig = inspect.signature(mod.__class__.forward) # type: ignore[attr-defined]
bound_args = sig.bind(mod, *args, **kwargs)
bound_args.apply_defaults()
for k, dims in dynamic_arg_dims.items():
arg = bound_args.arguments.get(k)
if arg is not None:
dims = [dims] if isinstance(dims, int) else dims
if isinstance(arg, torch.Tensor):
# In case dims is specified with negative indexing
dims = [arg.ndim + dim if dim < 0 else dim for dim in dims]
mark_dynamic(arg, dims)
elif isinstance(arg, IntermediateTensors):
for tensor in arg.tensors.values():
# In case dims is specified with negative indexing
dims = [tensor.ndim + dim if dim < 0 else dim for dim in dims]
mark_dynamic(tensor, dims)
else:
raise ValueError(
"Unsupported dynamic dimensions"
f" {dims} for argument {k} with type {type(arg)}."
)
if mark_unbacked_dims:
for k, dims in mark_unbacked_dims.items():
arg = bound_args.arguments.get(k)
if arg is not None:
dims = [dims] if isinstance(dims, int) else dims
if isinstance(arg, torch.Tensor):
# In case dims is specified with negative indexing
dims = [arg.ndim + dim if dim < 0 else dim for dim in dims]
if is_torch_equal_or_newer("2.10.0"):
for dim in dims:
torch._dynamo.decorators.mark_unbacked(
arg, dim, hint_override=arg.size()[dim]
)
else:
torch._dynamo.decorators.mark_unbacked(arg, dims)
def __call__(self: type[_T], *args: Any, **kwargs: Any) -> Any:
"""Dispatch a call to eager forward, a cached AOT artifact, or torch.compile.

The checks below are tried in order:

1. ``self.do_not_compile`` is set, or we are already inside a compilation:
   call ``self.forward`` directly.
2. A forward context is available and its ``skip_compiled`` flag is set:
   call ``self.forward`` directly.
3. ``self.aot_compiled_fn`` is already set: call it inside
   ``maybe_use_cudagraph_partition_wrapper``.
4. ``envs.VLLM_USE_AOT_COMPILE`` is set: derive the cache location and try
   to load a previously saved artifact; if that yields a function, call it.
5. ``self.compiled`` is set: delegate to
   ``TorchCompileWithNoGuardsWrapper.__call__``.
6. Otherwise this is the first compilation: mark dynamic inputs, start
   compile monitoring, record the files Dynamo traces through, and compile
   (AOT or via ``TorchCompileWithNoGuardsWrapper.__call__``).

NOTE(review): leading indentation is missing in this copy of the file, so
the exact nesting of a few statements (flagged below) should be confirmed
against the upstream source.
"""
# torch.compiler.is_compiling() means we are inside the compilation
# e.g. TPU has the compilation logic in model runner, so we don't
# need to compile the model inside.
if self.do_not_compile or torch.compiler.is_compiling():
return self.forward(*args, **kwargs)
# If skip_compiled is set, bypass compiled model call. This is used e.g. for
# enc-dec models where tensor shapes/types vary across invocations, preventing
# the capture of a single computational graph.
if is_forward_context_available() and get_forward_context().skip_compiled:
return self.forward(*args, **kwargs)
# if aot_compiled_fn is set, call it with partition wrapper context.
# The partition wrapper must be active at runtime for CUDA graph
# capture to work correctly with inductor graph partitioning.
if getattr(self, "aot_compiled_fn", None) is not None:
with maybe_use_cudagraph_partition_wrapper(self.vllm_config):
return self.aot_compiled_fn(self, *args, **kwargs)
ds_type = self.compilation_config.dynamic_shapes_config.type
# Both stay None unless AOT compilation is enabled below.
cache_dir = None
aot_compilation_path = None
if envs.VLLM_USE_AOT_COMPILE:
"""
When using torch.compile in AOT mode, we store the cache artifacts
under VLLM_CACHE_ROOT/torch_compile_cache/torch_aot_compile/{hash}
The {hash} contains all of the factors except for the source files
being traced through, because we don't actually know which source
files to check at this point (before dynamo runs).
On loading we will actually look at the source files being traced
through. If any source file have changed (compared with the
serialized backend artifacts), then we need to generate a new AOT
compile artifact from scratch.
"""
from .caching import aot_compile_hash_factors
# The cache key is the SHA-256 of the config-derived factors plus a
# key derived from the forward function.
factors: list[str] = aot_compile_hash_factors(self.vllm_config)
factors.append(_model_hash_key(self.forward))
hash_key = hashlib.sha256(str(factors).encode()).hexdigest()
cache_dir = os.path.join(
envs.VLLM_CACHE_ROOT,
"torch_compile_cache",
"torch_aot_compile",
hash_key,
)
# Each (rank, data-parallel index) pair gets its own sub-directory.
rank = self.vllm_config.parallel_config.rank
dp_rank = self.vllm_config.parallel_config.data_parallel_index
cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}")
aot_compilation_path = os.path.join(cache_dir, "model")
# Best-effort load: any exception raised here falls through to the
# handler below.
try:
with (
set_current_vllm_config(self.vllm_config),
open(aot_compilation_path, "rb") as f,
):
start_monitoring_torch_compile(self.vllm_config)
loaded_fn = torch.compiler.load_compiled_function(
f, f_globals=self.forward.__globals__
)
_verify_source_unchanged(loaded_fn.source_info(), self.vllm_config)
if not self.compilation_config.dynamic_shapes_config.evaluate_guards:
loaded_fn.disable_guard_check()
self.aot_compiled_fn = loaded_fn
self.was_aot_compile_fn_loaded_from_disk = True
except Exception as e:
# Only warn when an artifact exists on disk but could not be used.
if os.path.exists(aot_compilation_path):
if isinstance(e, EOFError):
message = "Compile cache file corrupted."
else:
message = str(e)
logger.warning(
"Compiling model again due to a load failure from %s, "
"reason: %s",
aot_compilation_path,
message,
)
# NOTE(review): presumably evaluated for every load failure, not only
# when the file exists -- confirm the nesting upstream.
if envs.VLLM_FORCE_AOT_LOAD:
raise e
if getattr(self, "aot_compiled_fn", None) is not None:
logger.info(
"Directly load AOT compilation from path %s", aot_compilation_path
)
# Apply partition wrapper context for proper CUDA graph capture
with maybe_use_cudagraph_partition_wrapper(self.vllm_config):
return self.aot_compiled_fn(self, *args, **kwargs)
# Already compiled in this process: reuse the compiled wrapper. Reaching
# this point with AOT enabled is only expected for the eager backend.
if self.compiled:
assert (
not envs.VLLM_USE_AOT_COMPILE
or self.vllm_config.compilation_config.backend == "eager"
)
return TorchCompileWithNoGuardsWrapper.__call__(self, *args, **kwargs) # type: ignore[arg-type]
# This is the path for the first compilation.
# the first compilation needs to have dynamic shapes marked
_mark_dynamic_inputs(
self,
ds_type,
*args,
**kwargs,
)
# here, it is the starting point of the `torch.compile` process
start_monitoring_torch_compile(self.vllm_config)
original_code_object = self.original_code_object()
logger.debug("Start compiling function %s", original_code_object)
# we do not want to delete the original code object entries since
# we depend on them now to look up cached compiled functions.
# torch._dynamo.eval_frame.remove_from_cache(original_code_object)
# collect all relevant files traced by Dynamo,
# so that the compilation cache can trigger re-compilation
# properly when any of these files change.
# 1. the file containing the top-level forward function
self.compilation_config.traced_files.add(original_code_object.co_filename)
# 2. every time Dynamo sees a function call, it will inline
# the function by calling InliningInstructionTranslator.inline_call_
# we hijack this function to know all the functions called
# during Dynamo tracing, and their corresponding files
inline_call = InliningInstructionTranslator.inline_call_
def patched_inline_call(self_: Any) -> Any:
# Record the file of the inlined code object, then defer to the original.
code = self_.f_code
self.compilation_config.traced_files.add(code.co_filename)
return inline_call(self_)
# Disable the C++ compilation of symbolic shape guards. C++-fication
# of symbolic shape guards can improve guard overhead. But, since
# vllm skip guards anyways, setting this flag to False can improve
# compile time.
dynamo_config_patches = {}
try:
# Attribute access is only a probe for whether the option exists.
_ = torch._dynamo.config.enable_cpp_symbolic_shape_guards
dynamo_config_patches["enable_cpp_symbolic_shape_guards"] = False
except AttributeError:
# Note: this config is not available in torch 2.6, we can skip
# if the config doesn't exist
logger.debug("enable_cpp_symbolic_shape_guards config not available")
# Prepare backed_size_oblivious config patch if needed
fx_config_patches = {}
if ds_type == DynamicShapesType.BACKED_SIZE_OBLIVIOUS:
fx_config_patches["backed_size_oblivious"] = True
# Prepare inductor config patches
# assume_32bit_indexing is only available in torch 2.10.0+
inductor_config_patches = {}
if is_torch_equal_or_newer("2.10.0"):
inductor_config_patches["assume_32bit_indexing"] = (
self.compilation_config.dynamic_shapes_config.assume_32_bit_indexing
)
# All patches below are active only for the duration of the compilation.
with (
patch.object(
InliningInstructionTranslator, "inline_call_", patched_inline_call
),
torch._dynamo.config.patch(**dynamo_config_patches),
maybe_use_cudagraph_partition_wrapper(self.vllm_config),
torch.fx.experimental._config.patch(**fx_config_patches),
torch._inductor.config.patch(**inductor_config_patches),
):
use_aot_compile = envs.VLLM_USE_AOT_COMPILE
if self.vllm_config.compilation_config.backend == "eager":
logger.warning("Detected eager backend, disabling AOT compile.")
use_aot_compile = False
if use_aot_compile:
from vllm.compilation.backends import set_on_compilation_complete
# store the path for saving after warmup
self._aot_compilation_path = aot_compilation_path
self._aot_cache_dir = cache_dir
# set callback in context so it's available when compilation completes
with set_on_compilation_complete(self.save_aot_compiled_function):
self.aot_compiled_fn = self.aot_compile(*args, **kwargs)
output = self.aot_compiled_fn(self, *args, **kwargs)
else:
output = TorchCompileWithNoGuardsWrapper.__call__(self, *args, **kwargs) # type: ignore[arg-type]
# NOTE(review): presumably set only on the non-AOT branch (the AOT branch
# is short-circuited by aot_compiled_fn on later calls) -- confirm nesting.
self.compiled = True
return output
# triggers VllmSerializableFunction.serialize()
def save_aot_compiled_function(self: type[_T]) -> None:
    """Persist the AOT-compiled function to the on-disk compile cache.

    Nothing is written when the function was itself loaded from the cache.
    The artifact is first saved to a process-specific temporary file and then
    moved into place with ``os.replace`` so that a reader never observes a
    partially written file.

    Saving is best effort: any failure is logged as a warning and not
    propagated. On failure the temporary file, if it was created, is removed
    so that aborted saves do not accumulate in the cache directory.
    """
    if self.was_aot_compile_fn_loaded_from_disk:
        logger.debug("AOT compiled function was loaded from cache, skipping save")
        return

    assert (
        self.aot_compiled_fn and self._aot_compilation_path and self._aot_cache_dir
    )

    logger.info("saving AOT compiled function to %s", self._aot_compilation_path)
    # File saving should be atomic, so we will save to a temporary location
    # first. Should be upstreamed to PyTorch 2.12 as well.
    tmp_file = f"{self._aot_compilation_path}.{os.getpid()}.tmp"
    try:
        os.makedirs(self._aot_cache_dir, exist_ok=True)
        self.aot_compiled_fn.save_compiled_function(tmp_file)
        os.replace(tmp_file, self._aot_compilation_path)
        logger.info("saved AOT compiled function to %s", self._aot_compilation_path)
    except Exception as e:
        logger.warning(
            "unable to save AOT compiled function to %s: %s",
            self._aot_compilation_path,
            e,
        )
        # Drop a partially written temp file; it is absent if the failure
        # happened before it was created or after it was moved into place.
        with contextlib.suppress(OSError):
            os.remove(tmp_file)
cls.__call__ = __call__
cls.save_aot_compiled_function = save_aot_compiled_function
return cls
@contextlib.contextmanager
def maybe_use_cudagraph_partition_wrapper(
    vllm_config: VllmConfig,
) -> Generator[None, None, None]:
    """
    Context manager to set/unset customized cudagraph partition wrappers.

    If we're using Inductor-based graph partitioning, we currently have the
    whole `fx.Graph` before Inductor lowering and the piecewise
    splitting happens after all graph passes and fusions. Here, we add
    a custom hook for Inductor to wrap each partition with our static
    graph wrapper class to maintain more control over static graph
    capture and replay.

    The hook is installed only when the compilation config has piecewise
    cudagraphs and `use_inductor_graph_partition` enabled; otherwise this
    context manager does nothing. When installed, the hook is always removed
    on exit, including when the body of the `with` block raises.
    """
    from vllm.config import CUDAGraphMode

    compilation_config = vllm_config.compilation_config
    # Decide once, so that "unset on exit" always mirrors "set on entry".
    use_partition_wrapper = (
        compilation_config.cudagraph_mode.has_piecewise_cudagraphs()
        and compilation_config.use_inductor_graph_partition
    )
    if not use_partition_wrapper:
        yield
        return

    from torch._inductor.utils import CUDAGraphWrapperMetadata

    from vllm.compilation.cuda_graph import CUDAGraphOptions
    from vllm.platforms import current_platform

    static_graph_wrapper_class = resolve_obj_by_qualname(
        current_platform.get_static_graph_wrapper_cls()
    )

    def customized_cudagraph_wrapper(
        f: Callable[..., Any], metadata: CUDAGraphWrapperMetadata
    ) -> Any:
        # Wrap one Inductor partition in the platform's static graph wrapper.
        partition_id = metadata.partition_index
        num_partitions = metadata.num_partitions
        return static_graph_wrapper_class(
            runnable=f,
            vllm_config=vllm_config,
            runtime_mode=CUDAGraphMode.PIECEWISE,
            cudagraph_options=CUDAGraphOptions(
                # Options differ for the first and the last partition.
                debug_log_enable=partition_id == 0,
                gc_disable=partition_id != 0,
                weak_ref_output=partition_id == num_partitions - 1,
            ),
        )

    torch._inductor.utils.set_customized_partition_wrappers(
        customized_cudagraph_wrapper
    )
    try:
        yield
    finally:
        # The hook is process-global state in Inductor; never leave it
        # installed after this context ends, even on error.
        torch._inductor.utils.set_customized_partition_wrappers(None)

View File

@@ -0,0 +1,63 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from vllm.config import CompilationConfig, CompilationMode, VllmConfig
from vllm.logger import init_logger
logger = init_logger(__name__)
# depyf debug context entered by start_monitoring_torch_compile and exited by
# end_monitoring_torch_compile; None while no debug dump is active.
context_manager = None
# time.perf_counter() reading taken when compile monitoring last started.
torch_compile_start_time: float = 0.0
def start_monitoring_torch_compile(vllm_config: VllmConfig) -> None:
    """Record the compile start time and, if configured, start a depyf dump.

    The module-level ``torch_compile_start_time`` is reset on every call.
    A depyf debug context is opened only when the compilation mode is
    ``CompilationMode.VLLM_COMPILE`` and ``compile_debug_dump_path()`` returns
    a path; it stays open until ``end_monitoring_torch_compile`` is called.

    Calling this function again while a depyf context is still open keeps the
    existing context instead of replacing it. Replacing it would drop the only
    reference to an entered context, which could then never be exited.

    Args:
        vllm_config: Configuration providing the compilation mode and the
            debug dump path.
    """
    global torch_compile_start_time, context_manager
    torch_compile_start_time = time.perf_counter()

    compilation_config: CompilationConfig = vllm_config.compilation_config
    path = vllm_config.compile_debug_dump_path()
    if compilation_config.mode == CompilationMode.VLLM_COMPILE and path:
        if context_manager is not None:
            # Monitoring was started earlier and not ended yet.
            logger.debug("depyf debug context already active; reusing it")
            return

        import depyf

        path.mkdir(parents=True, exist_ok=True)
        logger.debug("Dumping depyf output to %s", path)
        debug_context = depyf.prepare_debug(path.as_posix())
        debug_context.__enter__()
        # Publish only after a successful enter, so that the end hook never
        # calls __exit__ on a context that was not entered.
        context_manager = debug_context
def end_monitoring_torch_compile(vllm_config: VllmConfig) -> None:
"""Log the total compile time and close the depyf debug context, if any.

The elapsed time is measured from the module-level
``torch_compile_start_time`` set by ``start_monitoring_torch_compile`` and is
logged only in ``CompilationMode.VLLM_COMPILE``.

NOTE(review): indentation is missing in this copy of the file; the context
teardown below is presumably nested under the mode check -- confirm.
"""
compilation_config: CompilationConfig = vllm_config.compilation_config
total_compile_time: float = time.perf_counter() - torch_compile_start_time
if compilation_config.mode == CompilationMode.VLLM_COMPILE:
logger.info_once(
"torch.compile takes %.2f s in total",
total_compile_time,
scope="local",
)
global context_manager
if context_manager is not None:
# Exit without exception info and drop the reference so that a later
# start can open a fresh context.
context_manager.__exit__(None, None, None)
context_manager = None
# Process-wide switch consulted before CUDA graph capture; toggled through
# set_cudagraph_capturing_enabled().
cudagraph_capturing_enabled: bool = True


def validate_cudagraph_capturing_enabled() -> None:
    """Raise if CUDA graph capturing is currently disabled.

    Intended to be called before any cudagraph capture so that a capture
    happening at a disallowed time is reported at runtime.

    Raises:
        RuntimeError: If capturing has been disabled via
            ``set_cudagraph_capturing_enabled(False)``.
    """
    if cudagraph_capturing_enabled:
        return
    raise RuntimeError(
        "CUDA graph capturing detected at an inappropriate "
        "time. This operation is currently disabled."
    )


def set_cudagraph_capturing_enabled(enabled: bool) -> None:
    """Enable or disable CUDA graph capturing for the whole process.

    Args:
        enabled: New value of the module-level capture switch.
    """
    global cudagraph_capturing_enabled
    cudagraph_capturing_enabled = enabled

View File

@@ -0,0 +1,75 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
from collections.abc import Generator
import torch
from vllm.logger import init_logger
logger = init_logger(__name__)
def should_split(node: torch.fx.Node, splitting_ops: list[str]) -> bool:
    """
    Decide whether the dynamo graph should be partitioned at ``node``.

    Since this runs on the dynamo graph, ``node.target`` can be any object.
    Only ``call_function`` nodes whose target is an ``OpOverloadPacket`` or an
    ``OpOverload`` are ever split on.

    Args:
        node: The FX node to inspect.
        splitting_ops: Operator names to split on. A packet is matched by its
            qualified name; an overload is matched either by
            ``"<name>.<overload>"`` or by the name alone.

    Returns:
        True if the node's target matches an entry of ``splitting_ops``.
    """
    if node.op != "call_function":
        return False

    op = node.target
    if isinstance(op, torch._ops.OpOverloadPacket):
        # Example: "aten::add"
        return op._qualified_op_name in splitting_ops

    if not isinstance(op, torch._ops.OpOverload):
        return False

    # Example: "aten::add"
    base_name = op.name()
    # Example: "aten::add.default"
    overload_name = f"{base_name}.{op._overloadname}"
    return any(name in splitting_ops for name in (overload_name, base_name))
@contextlib.contextmanager
def inductor_partition_rule_context(
    splitting_ops: list[str] | None,
) -> Generator[None, None, None]:
    """Temporarily install Inductor partition rules for the given operators.

    While the context is active,
    ``torch._inductor.config.custom_should_partition_ops`` is set to
    ``splitting_ops``. The value that was in place on entry is put back on
    exit, including when the body raises. With an empty or ``None``
    ``splitting_ops`` the configuration is left untouched.

    Args:
        splitting_ops: List of operator names to partition on.
    """
    if not splitting_ops:
        logger.debug("No partition ops provided; skipping rule registration.")
        yield
        return

    inductor_config = torch._inductor.config
    # Snapshot a copy of the current rules so they can be reinstated later.
    previous_ops: list[str] = list(inductor_config.custom_should_partition_ops)

    inductor_config.custom_should_partition_ops = splitting_ops
    logger.debug(
        "Registered inductor partition rules for %d operators", len(splitting_ops)
    )

    try:
        yield
    finally:
        # Put back whatever was configured before entering the context.
        inductor_config.custom_should_partition_ops = previous_ops
        logger.debug("Restored previous partition rules state.")

View File

View File

@@ -0,0 +1,215 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from typing import Any
import torch
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import (
PatternMatcherPass,
fwd_only,
register_replacement,
)
from torch._ops import OpOverload
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kFp8StaticTensorSym,
kNvfp4Dynamic,
)
from vllm.platforms import current_platform
from ..inductor_pass import enable_fake_mode
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
from .matcher_utils import MatcherQuantFP8, MatcherSiluAndMul
from .rms_quant_fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
logger = init_logger(__name__)
# FP8 dtype reported by the current platform.
FP8_DTYPE = current_platform.fp8_dtype()
# FP4 values are held in uint8 tensors (presumably packed -- TODO confirm).
FP4_DTYPE = torch.uint8
# Unfused activation op.
SILU_MUL_OP = torch.ops._C.silu_and_mul.default
# Maps a quantization scheme to the fused silu_and_mul + quant op for it.
FUSED_OPS: dict[QuantKey, OpOverload] = {
kFp8StaticTensorSym: torch.ops._C.silu_and_mul_quant.default, # noqa: E501
}
# The NVFP4 fused op is optional: it is used only on CUDA platforms and only
# when the op is present in torch.ops._C.
silu_and_mul_nvfp4_quant_supported = current_platform.is_cuda() and hasattr(
torch.ops._C, "silu_and_mul_nvfp4_quant"
)
if silu_and_mul_nvfp4_quant_supported:
FUSED_OPS[kNvfp4Dynamic] = torch.ops._C.silu_and_mul_nvfp4_quant.default # noqa: E501
class ActivationQuantPattern(ABC):
    """
    Shared base for Activation+Quant fusion patterns.

    Resolves the unfused quant op and the fused op for a quantization scheme
    and provides the SiluAndMul matcher. Subclasses implement `register`;
    this class should not be used directly.
    """

    def __init__(
        self,
        quant_key: QuantKey,
    ) -> None:
        """
        Args:
            quant_key: Quantization scheme; must be a key of both
                `QUANT_OPS` and `FUSED_OPS`.
        """
        self.quant_key = quant_key
        self.quant_dtype = quant_key.dtype

        # Unfused quantization op matched by the pattern.
        assert self.quant_key in QUANT_OPS, (
            f"unsupported quantization scheme {self.quant_key}"
        )
        self.QUANT_OP = QUANT_OPS[self.quant_key]

        # Fused op emitted by the replacement.
        assert self.quant_key in FUSED_OPS, (
            f"unsupported fusion scheme {self.quant_key}"
        )
        self.FUSED_OP = FUSED_OPS[self.quant_key]

        self.silu_and_mul_matcher = MatcherSiluAndMul()

    def empty_quant(self, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Call `torch.empty` defaulting to the quant dtype on "cuda".

        Caller-supplied keyword arguments override the defaults.
        """
        options: dict[str, Any] = {"dtype": self.quant_dtype, "device": "cuda"}
        options.update(kwargs)
        return torch.empty(*args, **options)

    @abstractmethod
    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register this pattern and its replacement with `pm_pass`."""
        raise NotImplementedError
class SiluMulFp8StaticQuantPattern(ActivationQuantPattern):
    """
    Fuses SiluAndMul followed by static FP8 quantization into one op.
    """

    def __init__(self) -> None:
        super().__init__(kFp8StaticTensorSym)
        self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)

    def get_inputs(self) -> list[torch.Tensor]:
        """Example inputs used to trace the pattern: activation input(s), scale."""
        quant_inputs = self.quant_matcher.inputs()
        return [
            *self.silu_and_mul_matcher.inputs(),  # input
            quant_inputs[1],  # scale
        ]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the SiluMul+FP8-quant pattern and its fused replacement."""

        def pattern(
            input: torch.Tensor,
            scale: torch.Tensor,
        ) -> torch.Tensor:
            activated = self.silu_and_mul_matcher(input)
            return self.quant_matcher(activated, scale)[0]

        def replacement(
            input: torch.Tensor,
            scale: torch.Tensor,
        ) -> torch.Tensor:
            # The fused output keeps all leading dims and has half of the
            # last input dim.
            half = input.shape[-1] // 2
            fused_out = torch.empty(
                input.shape[:-1] + (half,),
                device=input.device,
                dtype=self.quant_dtype,
            )
            fused = auto_functionalized(
                self.FUSED_OP, result=fused_out, input=input, scale=scale
            )
            # Index 1 is the functionalized `result` tensor.
            return fused[1]

        example_inputs = self.get_inputs()
        # The pattern is invoked once on the example inputs before registration.
        pattern(*example_inputs)
        register_replacement(pattern, replacement, example_inputs, fwd_only, pm_pass)
class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
    """
    Fuses SiluAndMul followed by NVFP4 quantization into one op.
    """

    def __init__(self) -> None:
        super().__init__(kNvfp4Dynamic)

    def get_inputs(self) -> list[torch.Tensor]:
        """Example inputs for tracing: result, output_scale, input, scale."""
        return [
            self.empty_quant(5, 32),  # result
            empty_i32(128, 4),  # output_scale
            empty_bf16(5, 64),  # input
            empty_fp32(1, 1),  # scale
        ]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the SiluMul+NVFP4-quant pattern and its fused replacement."""

        def pattern(
            result: torch.Tensor,
            output_scale: torch.Tensor,
            input: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            activated = self.silu_and_mul_matcher(input)
            quantized = auto_functionalized(
                self.QUANT_OP,
                output=result,
                input=activated,
                output_scale=output_scale,
                input_scale=scale,
                is_sf_swizzled_layout=True,
            )
            return quantized[1], quantized[2]

        def replacement(
            result: torch.Tensor,
            output_scale: torch.Tensor,
            input: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            fused = auto_functionalized(
                self.FUSED_OP,
                result=result,
                result_block_scale=output_scale,
                input=input,
                input_global_scale=scale,
            )
            return fused[1], fused[2]

        register_replacement(pattern, replacement, self.get_inputs(), fwd_only, pm_pass)
class ActivationQuantFusionPass(VllmPatternMatcherPass):
    """
    Pass that rewrites a fixed set of custom-op sequences into fused ops.

    Matching and rewriting are done with the torch pattern matcher.
    Because patterns can only be registered once, the pass is a singleton.
    This will be addressed in a future version of PyTorch:
    https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
    """

    @enable_fake_mode
    def __init__(self, config: VllmConfig) -> None:
        super().__init__(config)

        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="activation_quant_fusion_pass"
        )

        # The FP8 pattern is always registered; the NVFP4 one only when its
        # fused op is available.
        SiluMulFp8StaticQuantPattern().register(self.patterns)
        if silu_and_mul_nvfp4_quant_supported:
            SiluMulNvfp4QuantPattern().register(self.patterns)

        self.dump_patterns(config, self.patterns)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: torch.fx.Graph) -> None:
        """Apply the registered patterns to `graph` and record the match count."""
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)

    def uuid(self) -> str:
        """Hash of the sources of this pass and of its pattern classes."""
        hashed_sources = (
            self,
            ActivationQuantPattern,
            SiluMulFp8StaticQuantPattern,
            SiluMulNvfp4QuantPattern,
        )
        return VllmInductorPass.hash_source(*hashed_sources)

View File

@@ -0,0 +1,862 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
from importlib.util import find_spec
from types import ModuleType
import torch
import torch._inductor.pattern_matcher as pm
import torch.fx as fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import PatternMatcherPass
from vllm.config import VllmConfig
from vllm.config.utils import Range
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8StaticTensorSym,
)
from vllm.platforms import current_platform
from vllm.utils.torch_utils import (
direct_register_custom_op,
)
from ..inductor_pass import enable_fake_mode
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm
# FP8 dtype reported by the current platform.
FP8_DTYPE = current_platform.fp8_dtype()
logger = init_logger(__name__)
# Stays None unless flashinfer is installed, flashinfer.comm imports cleanly
# and it exposes both functions probed below.
flashinfer_comm: ModuleType | None = None
if find_spec("flashinfer"):
try:
import flashinfer.comm as _flashinfer_comm
if hasattr(_flashinfer_comm, "allreduce_fusion") and hasattr(
_flashinfer_comm, "create_allreduce_fusion_workspace"
):
flashinfer_comm = _flashinfer_comm
except ImportError:
# A failed import is treated as "FlashInfer not available".
pass
# STATIC_FP4_QUANT_OP is defined only when the op exists in torch.ops._C.
if hasattr(torch.ops._C, "scaled_fp4_quant"):
STATIC_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant.default
# Max size of the input tensor per world size per device capability
# to use flashinfer fused allreduce
# Outer key: device capability as an int; inner key: world size; value: MB.
FI_ALLREDUCE_FUSION_MAX_SIZE_MB: dict[int, dict[int, float]] = {
90: {
2: 64, # 64MB
4: 2, # 2MB
8: 0.5, # 0.5MB
},
100: {
2: 64, # 64MB
4: 32, # 32MB
8: 1, # 1MB
},
}
# Max size of the input tensor per world size per device capability
# to use flashinfer one shot fused allreduce
# OneShot max size is at most 64MB / world size (FlashInfer restriction)
# Same key layout as the table above. A missing entry means no limit is
# applied when choosing one shot in call_trtllm_fused_allreduce_norm.
_FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB: dict[int, dict[int, float]] = {
90: {
2: 32, # 32MB
4: 2, # 2MB
8: 0.5, # 0.5MB
},
100: {
2: 32, # 32MB
4: 4, # 4MB
8: 1, # 1MB
},
}
# Everything below that needs FlashInfer is set up only when the probe above
# found a usable flashinfer.comm module.
if flashinfer_comm is not None:
from vllm.distributed.device_communicators.flashinfer_all_reduce import (
destroy_fi_ar_workspace,
get_fi_ar_quant_workspace,
get_fi_ar_workspace,
initialize_fi_ar_quant_workspace,
initialize_fi_ar_workspace,
)
# Shorthand for FlashInfer's enumeration of fusion pattern codes.
ar_fusion_patterns = flashinfer_comm.AllReduceFusionPattern
# Bytes per mebibyte; converts the MB limits in the tables above to bytes.
MiB = 1024 * 1024
def call_trtllm_fused_allreduce_norm(
    allreduce_in: torch.Tensor,
    residual: torch.Tensor,
    rms_gamma: torch.Tensor,
    rms_eps: float,
    world_size: int,
    launch_with_pdl: bool,
    fp32_acc: bool,
    max_token_num: int,
    pattern_code: int,
    norm_out: torch.Tensor | None = None,
    quant_out: torch.Tensor | None = None,
    scale_out: torch.Tensor | None = None,
    scale_factor: torch.Tensor | None = None,
) -> None:
    """Run FlashInfer's fused all-reduce + RMSNorm (+ optional quant) kernel.

    Nothing is returned; the kernel writes into the tensors handed to it.

    Args:
        allreduce_in: 2-D input, unpacked as (num_tokens, hidden_size).
        residual: Passed to the kernel as `residual_in`.
        rms_gamma: RMSNorm weight.
        rms_eps: RMSNorm epsilon.
        world_size: Used to look up the one-shot size limit.
        launch_with_pdl: Forwarded to FlashInfer.
        fp32_acc: Forwarded to FlashInfer.
        max_token_num: Upper bound on the number of tokens; asserted.
        pattern_code: FlashInfer fusion pattern; also selects the workspace.
        norm_out: Optional separate output for the normalized result.
        quant_out: Optional quantized output.
        scale_out: Optional output for scales.
        scale_factor: Optional quantization scale input.
    """
    num_tokens, hidden_size = allreduce_in.shape
    element_size = allreduce_in.element_size()
    current_tensor_size = num_tokens * hidden_size * element_size
    max_tensor_size = max_token_num * hidden_size * element_size
    assert current_tensor_size <= max_tensor_size, (
        f"Current tensor size {current_tensor_size} is larger than "
        f"max token num {max_token_num} * hidden size {hidden_size} * "
        f"element size {element_size}"
    )

    capability = current_platform.get_device_capability()
    device_capability = None if capability is None else capability.to_int()

    # Get one shot input size limit for the current world size
    # for the current device capability
    limits_for_device = _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB.get(
        device_capability,  # type: ignore[arg-type, unused-ignore]
        {},
    )
    max_one_shot_size = limits_for_device.get(world_size, None)

    # Use one shot if no max size is specified
    if max_one_shot_size is None:
        use_oneshot = True
    else:
        use_oneshot = current_tensor_size <= max_one_shot_size * MiB

    # Select workspace based on pattern: quant patterns use the
    # trtllm quant workspace, non-quant patterns use the primary workspace.
    quant_pattern_codes = (
        ar_fusion_patterns.kARResidualRMSNormFP8Quant,
        ar_fusion_patterns.kARResidualRMSNormFP4Quant,
    )
    workspace = (
        get_fi_ar_quant_workspace()
        if pattern_code in quant_pattern_codes
        else get_fi_ar_workspace()
    )
    assert workspace is not None, (
        "Flashinfer workspace must be initialized when using flashinfer"
    )
    assert flashinfer_comm is not None

    if norm_out is not None:
        # return residual_out as allreduce_out with zeroed residual_in
        # as flashinfer does not support rms_norm
        # and allreduce_out together
        residual_out = allreduce_in
    else:
        # No separate norm output: the norm result goes to allreduce_in and
        # the residual output goes to `residual`.
        norm_out = allreduce_in
        residual_out = residual

    # layout_code only supported by trtllm backend
    # in vllm we only support swizzled layout
    layout_code = (
        flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4
        if workspace.backend == "trtllm"
        else None
    )

    flashinfer_comm.allreduce_fusion(
        input=allreduce_in,
        workspace=workspace,
        pattern=pattern_code,
        launch_with_pdl=launch_with_pdl,
        output=None,
        residual_out=residual_out,
        norm_out=norm_out,
        quant_out=quant_out,
        scale_out=scale_out,
        residual_in=residual,
        rms_gamma=rms_gamma,
        rms_eps=rms_eps,
        scale_factor=scale_factor,
        layout_code=layout_code,
        use_oneshot=use_oneshot,
        fp32_acc=fp32_acc,
    )
def call_trtllm_fused_allreduce_norm_fake(
allreduce_in: torch.Tensor,
residual: torch.Tensor,
rms_gamma: torch.Tensor,
rms_eps: float,
world_size: int,
launch_with_pdl: bool,
fp32_acc: bool,
max_token_num: int,
pattern_code: int,
norm_out: torch.Tensor | None = None,
quant_out: torch.Tensor | None = None,
scale_out: torch.Tensor | None = None,
scale_factor: torch.Tensor | None = None,
) -> None:
pass
# Register the fused call as a custom op. The tensors named in `mutates_args`
# are declared as written in place; the fake implementation is the no-op
# defined above.
direct_register_custom_op(
op_name="flashinfer_trtllm_fused_allreduce_norm",
op_func=call_trtllm_fused_allreduce_norm,
mutates_args=[
"allreduce_in",
"residual",
"norm_out",
"quant_out",
"scale_out",
],
fake_impl=call_trtllm_fused_allreduce_norm_fake,
)
# Handle to the registered op's default overload, used by the replacement
# functions of the fusion patterns in this module.
flashinfer_trtllm_fused_allreduce_norm = (
torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default
)
class FlashInferFusedAllReduceParams:
"""Parameters for FlashInfer fused allreduce operations."""
def __init__(
self,
world_size: int,
max_token_num: int = 1024,
) -> None:
self.world_size = world_size
self.launch_with_pdl = True
self.fp32_acc = True
self.max_token_num = max_token_num
def get_trtllm_fused_allreduce_kwargs(self) -> dict[str, bool | int]:
return {
"world_size": self.world_size,
"launch_with_pdl": self.launch_with_pdl,
"fp32_acc": self.fp32_acc,
"max_token_num": self.max_token_num,
}
# TODO(luka): unify
class BasePattern:
    """State shared by the all-reduce fusion patterns.

    Stores the dtype and device used for example inputs together with the
    tensor-parallel group and its world size.
    """

    def __init__(self, dtype: torch.dtype, device: str | None) -> None:
        self.dtype = dtype
        self.device = device

        # Tensor-parallel group handle and size, looked up at construction.
        self.tp = get_tp_group()
        self.tp_size = get_tensor_model_parallel_world_size()
class AllReduceRMSNormPattern(BasePattern):
    """
    Rewrites allreduce + rms norm (without residual) into the fused
    flashinfer op.

    Applies to allreduce + rmsnorm before attn in the first Transformer block.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str | None,
        allreduce_params: FlashInferFusedAllReduceParams,
    ) -> None:
        super().__init__(dtype, device)
        self.epsilon = epsilon
        self.allreduce_params = allreduce_params
        self.rmsnorm_matcher = MatcherRMSNorm(epsilon)

    def get_inputs(self) -> list[torch.Tensor]:
        """Example inputs for tracing: [input, weight]."""
        sample, gamma = self.rmsnorm_matcher.inputs()

        # input goes through allreduce first, always 16-bit
        return [sample.to(self.dtype), gamma]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern and its fused replacement with `pm_pass`."""

        def pattern(
            input: torch.Tensor, weight: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor]:
            reduced = tensor_model_parallel_all_reduce(input)
            normed = self.rmsnorm_matcher(reduced, weight)
            return normed, reduced

        def replacement(
            input: torch.Tensor, weight: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # No residual in this pattern: feed zeros and take the norm
            # result through a separate output buffer.
            zero_residual = torch.zeros_like(input)
            norm_buffer = torch.empty_like(input)
            assert flashinfer_comm is not None, "FlashInfer must be enabled"
            fused = auto_functionalized(
                flashinfer_trtllm_fused_allreduce_norm,
                allreduce_in=input,
                residual=zero_residual,
                norm_out=norm_buffer,
                quant_out=None,
                scale_out=None,
                rms_gamma=weight,
                rms_eps=self.epsilon,
                pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm,
                **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
            )
            # rms_result, allreduce_in
            return fused[3], fused[1]

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )
class AllReduceFusedAddRMSNormPattern(BasePattern):
    """
    Rewrites allreduce + rms norm (with residual) into the fused
    flashinfer op.

    Applies to o_proj + rmsnorm after attn and mlp + rmsnorm before attn.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str | None,
        allreduce_params: FlashInferFusedAllReduceParams,
    ) -> None:
        super().__init__(dtype, device)
        self.epsilon = epsilon
        self.allreduce_params = allreduce_params
        self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)

    def get_inputs(self) -> list[torch.Tensor]:
        """Example inputs for tracing: [residual, input, weight]."""
        sample, sample_residual, gamma = self.rmsnorm_matcher.inputs()

        # input goes through allreduce first, always 16-bit
        return [sample_residual, sample.to(self.dtype), gamma]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern twice: with and without the residual output."""

        def pattern(
            residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor]:
            reduced = tensor_model_parallel_all_reduce(input)
            normed, residual_out = self.rmsnorm_matcher(reduced, weight, residual)
            return normed, residual_out

        def replacement(
            residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor]:
            assert flashinfer_comm is not None, "FlashInfer must be enabled"
            fused = auto_functionalized(
                flashinfer_trtllm_fused_allreduce_norm,
                allreduce_in=input,
                residual=residual,
                norm_out=None,
                quant_out=None,
                scale_out=None,
                rms_gamma=weight,
                rms_eps=self.epsilon,
                pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm,
                **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
            )
            # allreduce_in, residual
            return fused[1], fused[2]

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )

        # Same pattern, but only return the output and not residual
        # (helpful for end of graph where residual is not used again)
        first_return_only = lambda fn: lambda a, b, c: fn(a, b, c)[0]

        output_only_pattern = first_return_only(pattern)  # type: ignore[no-untyped-call]
        output_only_replacement = first_return_only(replacement)  # type: ignore[no-untyped-call]
        pm.register_replacement(
            output_only_pattern,
            output_only_replacement,
            self.get_inputs(),
            pm.fwd_only,
            pm_pass,
        )
class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern):
    """
    Rewrites allreduce + rms norm (without residual) + static fp8 quant
    into the fused flashinfer op.

    Applies to allreduce + rmsnorm + quant before attn
    in the first Transformer block.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str | None,
        allreduce_params: FlashInferFusedAllReduceParams,
    ) -> None:
        super().__init__(dtype, device)
        self.epsilon = epsilon
        self.allreduce_params = allreduce_params
        self.quant_dtype = torch.float8_e4m3fn
        self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
        self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)

    def get_inputs(self) -> list[torch.Tensor]:
        """Example inputs for tracing: [input, weight, scale]."""
        sample, gamma = self.rmsnorm_matcher.inputs()
        _, quant_scale = self.quant_matcher.inputs()

        # input goes through allreduce first, always 16-bit
        return [sample.to(self.dtype), gamma, quant_scale]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern and its fused replacement with `pm_pass`."""

        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            reduced = tensor_model_parallel_all_reduce(input)
            normed = self.rmsnorm_matcher(reduced, weight)
            quantized, _ = self.quant_matcher(normed, scale)
            return quantized, reduced

        def replacement(
            input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # No residual in this pattern: feed zeros, and give the op
            # separate buffers for the norm and quant outputs.
            zero_residual = torch.zeros_like(input)
            norm_buffer = torch.empty_like(input)
            quant_buffer = torch.empty_like(input, dtype=self.quant_dtype)
            assert flashinfer_comm is not None, "FlashInfer must be enabled"
            fused = auto_functionalized(
                flashinfer_trtllm_fused_allreduce_norm,
                allreduce_in=input,
                residual=zero_residual,
                norm_out=norm_buffer,
                quant_out=quant_buffer,
                scale_out=None,
                rms_gamma=weight,
                rms_eps=self.epsilon,
                # We don't use norm_out afterwards
                pattern_code=(
                    flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP8Quant
                ),
                scale_factor=scale,
                **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
            )

            # quant_out, allreduce_output
            return fused[4], fused[1]

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )
class AllReduceFusedAddRMSNormStaticQuantFP8Pattern(BasePattern):
    """
    Rewrites allreduce + rms norm (with residual) + static fp8 quant
    into the fused flashinfer op.

    Applies to o_proj + rmsnorm after attn + quant and
    mlp + rmsnorm + quant before attn.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str | None,
        allreduce_params: FlashInferFusedAllReduceParams,
    ) -> None:
        super().__init__(dtype, device)
        self.epsilon = epsilon
        self.allreduce_params = allreduce_params
        self.quant_dtype = torch.float8_e4m3fn

        self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
        self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)

    def get_inputs(self) -> list[torch.Tensor]:
        """Example inputs for tracing: [residual, input, weight, scale]."""
        sample, sample_residual, gamma = self.rmsnorm_matcher.inputs()
        _, quant_scale = self.quant_matcher.inputs()

        # input goes through allreduce first, always 16-bit
        return [sample_residual, sample.to(self.dtype), gamma, quant_scale]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern and its fused replacement with `pm_pass`."""

        def pattern(
            residual: torch.Tensor,
            input: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            reduced = tensor_model_parallel_all_reduce(input)
            normed, residual_out = self.rmsnorm_matcher(reduced, weight, residual)
            quantized, _ = self.quant_matcher(normed, scale)

            return quantized, residual_out

        def replacement(
            residual: torch.Tensor,
            input: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            quant_buffer = torch.empty_like(input, dtype=self.quant_dtype)
            assert flashinfer_comm is not None, "FlashInfer must be enabled"
            fused = auto_functionalized(
                flashinfer_trtllm_fused_allreduce_norm,
                allreduce_in=input,
                residual=residual,
                norm_out=None,
                quant_out=quant_buffer,
                scale_out=None,
                rms_gamma=weight,
                rms_eps=self.epsilon,
                # We don't use norm_out afterwards
                pattern_code=(
                    flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP8Quant
                ),
                scale_factor=scale,
                **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
            )
            # quant_out, rms_norm_residual
            return fused[4], fused[2]

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )
class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
    """
    Fuse allreduce + RMSNorm (without residual) + static nvfp4 quant into
    the fused flashinfer implementation.

    Applies to allreduce + rmsnorm + quant before attn
    in the first Transformer block.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str | None,
        allreduce_params: FlashInferFusedAllReduceParams,
    ) -> None:
        """
        Args:
            epsilon: RMSNorm epsilon this pattern instance is built for.
            dtype: Dtype of the example input and weight tensors.
            device: Device the example tensors are created on.
            allreduce_params: Supplies the extra kwargs of the fused op.
        """
        super().__init__(dtype, device)
        self.epsilon = epsilon
        self.allreduce_params = allreduce_params
        self.rmsnorm_matcher = MatcherRMSNorm(epsilon)

    def get_inputs(self) -> list[torch.Tensor]:
        """
        Return example tensors ordered as
        (input, quant_result, weight, input_global_scale, output_scale).
        """
        dev = self.device
        input = torch.empty([1, 16, 16], device=dev, dtype=self.dtype)
        quant_result = torch.empty((16, 8), device=dev, dtype=torch.uint8)
        input_global_scale = torch.empty([1, 1], device=dev, dtype=torch.float32)
        weight = torch.empty([16], device=dev, dtype=self.dtype)
        output_scale = torch.empty([128, 4], device=dev, dtype=torch.int32)
        return [input, quant_result, weight, input_global_scale, output_scale]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the search pattern and its fused replacement on `pm_pass`."""

        def pattern(
            input: torch.Tensor,
            quant_result: torch.Tensor,
            weight: torch.Tensor,
            input_global_scale: torch.Tensor,
            output_scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            reduced = tensor_model_parallel_all_reduce(input)
            normed = self.rmsnorm_matcher(reduced, weight)
            quant_result_tuple = auto_functionalized(
                STATIC_FP4_QUANT_OP,
                output=quant_result,
                input=normed,
                output_scale=output_scale,
                input_scale=input_global_scale,
                is_sf_swizzled_layout=True,
            )
            # Returned as (quant_out, allreduce_output, output_scale).
            quant_out = quant_result_tuple[1]
            scale_out = quant_result_tuple[2]
            return quant_out, reduced, scale_out

        def replacement(
            input: torch.Tensor,
            quant_result: torch.Tensor,
            weight: torch.Tensor,
            input_global_scale: torch.Tensor,
            output_scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            # This variant has no residual input, so a zero tensor is passed.
            zero_residual = torch.zeros_like(input)
            norm_buffer = torch.empty_like(input)
            assert flashinfer_comm is not None, "FlashInfer must be enabled"
            fusion_code = (
                flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP4Quant
            )
            fused = auto_functionalized(
                flashinfer_trtllm_fused_allreduce_norm,
                allreduce_in=input,
                residual=zero_residual,
                # norm_out gets a buffer here but is not returned below.
                norm_out=norm_buffer,
                quant_out=quant_result,
                scale_out=output_scale,
                rms_gamma=weight,
                rms_eps=self.epsilon,
                pattern_code=fusion_code,
                scale_factor=input_global_scale,
                **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
            )
            # Returned as (quant_out, allreduce_output, output_scale).
            return fused[4], fused[1], fused[5]

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )
class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
    """
    Fuse allreduce + RMSNorm (with residual) + static nvfp4 quant into
    the fused flashinfer implementation.

    Applies to o_proj + rmsnorm after attn + quant and
    mlp + rmsnorm + quant before attn.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str | None,
        allreduce_params: FlashInferFusedAllReduceParams,
    ) -> None:
        """
        Args:
            epsilon: RMSNorm epsilon this pattern instance is built for.
            dtype: Dtype of the example input, residual and weight tensors.
            device: Device the example tensors are created on.
            allreduce_params: Supplies the extra kwargs of the fused op.
        """
        super().__init__(dtype, device)
        self.epsilon = epsilon
        self.allreduce_params = allreduce_params
        self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)

    def get_inputs(self) -> list[torch.Tensor]:
        """
        Return example tensors ordered as (quant_result, residual, input,
        output_scale, weight, input_global_scale).
        """
        dev = self.device
        input = torch.empty([16, 16], device=dev, dtype=self.dtype)
        residual = torch.empty([16, 16], device=dev, dtype=self.dtype)
        weight = torch.empty([16, 16], device=dev, dtype=self.dtype)
        quant_result = torch.empty((16, 8), device=dev, dtype=torch.uint8)
        input_global_scale = torch.empty([1, 1], device=dev, dtype=torch.float32)
        output_scale = torch.empty([128, 4], device=dev, dtype=torch.int32)
        example_inputs = [
            quant_result,
            residual,
            input,
            output_scale,
            weight,
            input_global_scale,
        ]
        return example_inputs

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the search pattern and its fused replacement on `pm_pass`."""

        def pattern(
            quant_result: torch.Tensor,
            residual: torch.Tensor,
            input: torch.Tensor,
            output_scale: torch.Tensor,
            weight: torch.Tensor,
            input_global_scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            reduced = tensor_model_parallel_all_reduce(input)
            normed, residual_out = self.rmsnorm_matcher(reduced, weight, residual)
            quant_result_tuple = auto_functionalized(
                STATIC_FP4_QUANT_OP,
                output=quant_result,
                input=normed,
                output_scale=output_scale,
                input_scale=input_global_scale,
                is_sf_swizzled_layout=True,
            )
            # Returned as (quant_out, rms_norm residual, output_scale).
            quant_out = quant_result_tuple[1]
            scale_out = quant_result_tuple[2]
            return quant_out, residual_out, scale_out

        def replacement(
            quant_result: torch.Tensor,
            residual: torch.Tensor,
            input: torch.Tensor,
            output_scale: torch.Tensor,
            weight: torch.Tensor,
            input_global_scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            assert flashinfer_comm is not None, "FlashInfer must be enabled"
            fusion_code = (
                flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP4Quant
            )
            fused = auto_functionalized(
                flashinfer_trtllm_fused_allreduce_norm,
                allreduce_in=input,
                residual=residual,
                # norm_out is not used afterwards, so no buffer is passed.
                norm_out=None,
                quant_out=quant_result,
                scale_out=output_scale,
                rms_gamma=weight,
                rms_eps=self.epsilon,
                pattern_code=fusion_code,
                scale_factor=input_global_scale,
                **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
            )
            # Returned as (quant_out, rms_norm_residual, output_scale).
            return fused[4], fused[2], fused[5]

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )
class AllReduceFusionPass(VllmPatternMatcherPass):
    """
    Pattern-matcher pass that replaces allreduce + RMSNorm (+ quant)
    sequences with the fused flashinfer allreduce op.

    The pass starts out disabled and only becomes enabled once every
    precondition checked in `__init__` holds and all patterns are registered.
    """

    def __init__(self, config: VllmConfig) -> None:
        """
        Check preconditions, initialize the flashinfer workspaces and
        register the fusion patterns.

        Every early `return` below leaves `self.disabled` set to True, in
        which case `__call__` and `is_applicable_for_range` are no-ops.
        """
        super().__init__(config)
        self.disabled = True
        self.tp_size = get_tensor_model_parallel_world_size()
        if self.tp_size <= 1:
            logger.warning_once("AllReduce fusion pass is disabled for tp_size <= 1.")
            return
        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="all_reduce_fusion_pass"
        )
        if config.model_config is None:
            logger.warning_once(
                "AllReduce fusion pass is disabled for missing model_config."
            )
            return
        self.hidden_dim = config.model_config.get_hidden_size()
        self.group = get_tp_group().device_group
        rank = get_tensor_model_parallel_rank()
        if flashinfer_comm is None:
            logger.warning(
                "Flashinfer is not installed or comm module not found, "
                "skipping allreduce fusion pass"
            )
            return
        max_size = config.compilation_config.pass_config.flashinfer_max_size(
            self.tp_size
        )
        if max_size is None:
            # Flashinfer doesn't support current world size
            logger.warning(
                "Flashinfer allreduce fusion is not supported for world size %s"
                " or max size is not provided",
                self.tp_size,
            )
            return
        element_size = torch.tensor([], dtype=self.model_dtype).element_size()
        # take the min to save workspace size and we'll never use more
        # than max_num_batched_tokens anyways
        self.max_token_num = min(
            max_size // (self.hidden_dim * element_size),
            config.scheduler_config.max_num_batched_tokens,
        )
        # The separator after "MB," keeps the two sentences from running
        # together in the emitted message.
        logger.debug_once(
            f"Flashinfer max size: {max_size // (1024 * 1024)} MB, "
            "Maximal number of tokens used by "
            f"Flashinfer Allreduce Fusion: {self.max_token_num}",
            scope="global",
        )

        # NOTE(review): if the first workspace initializes and the second one
        # fails, the pass stays disabled and `__del__` skips
        # destroy_fi_ar_workspace() -- confirm whether the first workspace
        # needs explicit cleanup in that case.
        for workspace_init_fn in [
            initialize_fi_ar_workspace,
            initialize_fi_ar_quant_workspace,
        ]:
            try:
                workspace_init_fn(
                    world_size=self.tp_size,
                    rank=rank,
                    max_token_num=self.max_token_num,
                    hidden_dim=self.hidden_dim,
                    dtype=self.model_dtype,
                    group=self.group,
                )
            except Exception as e:
                if "multicast" in str(e).lower():
                    logger.warning(
                        "AllReduce fusion pass is disabled: flashinfer workspace "
                        "creation failed: %s. This is expected on GPUs without "
                        "NVSwitch (e.g., NVLink bridge-only or PCIe topologies). "
                        "Falling back to non-fused allreduce.",
                        str(e),
                    )
                else:
                    logger.warning(
                        "Failed to initialize FlashInfer All Reduce workspace: %s. "
                        "AllReduce fusion pass will be disabled.",
                        e,
                    )
                return

        self.allreduce_params = FlashInferFusedAllReduceParams(
            world_size=self.tp_size,
            max_token_num=self.max_token_num,
        )
        self.register_patterns()
        self.dump_patterns(config, self.patterns)

    @enable_fake_mode
    def register_patterns(self) -> None:
        """
        Register every fusion pattern for each supported epsilon and enable
        the pass.

        Quantized patterns are registered only when the quant workspace
        exists; the nvfp4 ones additionally require
        `current_platform.has_device_capability(100)`.
        """
        supports_quantization = get_fi_ar_quant_workspace() is not None

        # Build the class list once; the order matches the registration
        # order used for every epsilon.
        pattern_classes: list[type] = []
        if supports_quantization:
            pattern_classes += [
                AllReduceFusedRMSNormStaticQuantFP8Pattern,
                AllReduceFusedAddRMSNormStaticQuantFP8Pattern,
            ]
            if current_platform.has_device_capability(100):
                pattern_classes += [
                    AllReduceFusedRMSNormStaticQuantNVFP4Pattern,
                    AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern,
                ]
        pattern_classes += [
            AllReduceRMSNormPattern,
            AllReduceFusedAddRMSNormPattern,
        ]

        for epsilon in [1e-5, 1e-6]:
            for pattern_cls in pattern_classes:
                pattern_cls(
                    epsilon,
                    self.model_dtype,
                    self.device,
                    self.allreduce_params,
                ).register(self.patterns)

            # WARNING: This is a hack to clear the pattern matcher cache
            # and allow multiple values of epsilon.
            torch._inductor.pattern_matcher._seen_patterns.clear()

        self.disabled = False

    def is_applicable_for_range(self, compile_range: Range) -> bool:
        """
        Return True when the pass is enabled and the end of `compile_range`
        does not exceed `self.max_token_num`.
        """
        if self.disabled:
            logger.warning_once("AllReduce fusion pass is disabled.")
            return False
        return bool(compile_range.end <= self.max_token_num)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph) -> None:
        """Apply the registered patterns to `graph` unless the pass is disabled."""
        if self.disabled:
            logger.debug("AllReduceFusionPass disabled")
            return
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)

    def __del__(self) -> None:
        """Destroy the flashinfer workspace if this pass was ever enabled."""
        # `disabled` may be missing if __init__ failed very early; treat that
        # the same as disabled.
        if getattr(self, "disabled", True):
            return
        # Best-effort cleanup: errors during teardown are deliberately ignored.
        with contextlib.suppress(Exception):
            destroy_fi_ar_workspace()

View File

@@ -0,0 +1,374 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import Any, ParamSpec
import torch
import torch._inductor.pattern_matcher as pm
from torch import fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import PatternMatcherPass
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kNvfp4Dynamic,
kStaticTensorScale,
)
from vllm.platforms import current_platform
from vllm.utils.math_utils import round_up
from ..fx_utils import is_func
from ..inductor_pass import enable_fake_mode
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
from .matcher_utils import MatcherQuantFP8
from .rms_quant_fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
# Module-level logger.
logger = init_logger(__name__)
# Parameter specification used to type `AttentionQuantPattern.wrap_trace_fn`.
P = ParamSpec("P")
# fp8 dtype reported by the current platform.
FP8_DTYPE = current_platform.fp8_dtype()
# fp4 data is carried in uint8 tensors
# (presumably two values packed per byte -- TODO confirm).
FP4_DTYPE = torch.uint8
# Attention op matched (and re-emitted with output scales) by the patterns below.
ATTN_OP = torch.ops.vllm.unified_attention_with_output.default
# Reshape op used to flatten the attention output in patterns and replacements.
RESHAPE_OP = torch.ops.aten.reshape.default
class AttentionQuantPattern(ABC):
    """
    Common base for the Attn+Quant fusion patterns.

    Not meant to be instantiated directly; subclasses implement `_register`.
    """

    def __init__(
        self,
        layer: Attention,
        quant_key: QuantKey,
        dtype: torch.dtype,
    ) -> None:
        """
        Args:
            layer: Attention layer this pattern is specialized for.
            quant_key: Quantization scheme; must be a key of `QUANT_OPS`.
            dtype: Dtype used by `empty()` for example tensors.
        """
        self.layer = layer
        self.layer_name = layer.layer_name
        self.num_heads = layer.num_heads
        self.head_size = layer.head_size
        self.quant_key = quant_key
        self.quant_dtype = quant_key.dtype
        self.dtype = dtype

        assert self.quant_key in QUANT_OPS, (
            f"unsupported quantization scheme {self.quant_key}"
        )
        self.QUANT_OP = QUANT_OPS[self.quant_key]

    def empty(self, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Create an empty tensor defaulting to `self.dtype` on "cuda"."""
        options: dict[str, Any] = {"dtype": self.dtype, "device": "cuda"}
        # Caller-supplied kwargs take precedence over the defaults.
        options.update(kwargs)
        return torch.empty(*args, **options)

    def empty_quant(self, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Create an empty tensor defaulting to `self.quant_dtype` on "cuda"."""
        options: dict[str, Any] = {"dtype": self.quant_dtype, "device": "cuda"}
        # Caller-supplied kwargs take precedence over the defaults.
        options.update(kwargs)
        return torch.empty(*args, **options)

    @staticmethod
    def wrap_trace_fn(
        trace_fn: Callable[P, fx.GraphModule],
        *process_fx_fns: Callable[[fx.GraphModule], None],
    ) -> Callable[P, fx.GraphModule]:
        """
        Return a trace function that runs `trace_fn` and then applies each
        of `process_fx_fns`, in order, to the traced graph module.
        """

        def wrapped(*args: P.args, **kwargs: P.kwargs) -> fx.GraphModule:
            traced = trace_fn(*args, **kwargs)
            for post_process in process_fx_fns:
                post_process(traced)
            return traced

        return wrapped

    @staticmethod
    def fx_view_to_reshape(gm: torch.fx.GraphModule) -> None:
        """Run inductor's `view_to_reshape` post-grad pass on `gm`."""
        from torch._inductor.fx_passes.post_grad import view_to_reshape

        view_to_reshape(gm)

    @staticmethod
    def remove_noop_permutes(gm: torch.fx.GraphModule) -> None:
        """Erase `aten.permute` nodes whose dims are the identity ordering."""
        for node in gm.graph.nodes:
            if is_func(node, torch.ops.aten.permute.default):
                dims = node.args[1]
                is_identity = not any(dim != i for i, dim in enumerate(dims))
                if is_identity:
                    # An identity permute is a no-op: forward its input and
                    # drop the node.
                    node.replace_all_uses_with(node.args[0])
                    gm.graph.erase_node(node)

    def register_if_supported(self, pm_pass: PatternMatcherPass) -> None:
        """
        Register the pattern only when the layer's attention implementation
        reports support for fusing this quantization scheme.
        """
        if not self.layer.impl.fused_output_quant_supported(self.quant_key):
            return
        self._register(pm_pass)

    @abstractmethod
    def _register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the concrete pattern/replacement pair on `pm_pass`."""
        raise NotImplementedError
class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
    """
    Fusion for Attention+Fp8StaticQuant.

    Only triggers when the attention implementation returns True in
    `fused_output_quant_supported()`. When the pattern is found, the
    Fp8StaticQuant op is removed from the graph and its scale is passed
    to the Attention op as the `output_scale` argument.
    """

    def __init__(
        self,
        layer: Attention,
        dtype: torch.dtype,
        symmetric: bool = True,
    ) -> None:
        """
        Args:
            layer: Attention layer this pattern is specialized for.
            dtype: Dtype of the example (unquantized) tensors.
            symmetric: Forwarded to the `QuantKey` built for this pattern.
        """
        fp8_key = QuantKey(
            dtype=FP8_DTYPE, scale=kStaticTensorScale, symmetric=symmetric
        )
        super().__init__(layer, fp8_key, dtype)
        self.quant_matcher = MatcherQuantFP8(fp8_key)

    def _register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the attention + static fp8 quant fusion on `pm_pass`."""

        def pattern(
            q: torch.Tensor,
            k: torch.Tensor,
            v: torch.Tensor,
            output_attn: torch.Tensor,
            scale: torch.Tensor,
            kv_cache_dummy_dep: torch.Tensor,
        ) -> torch.Tensor:
            attn_result = auto_functionalized(
                ATTN_OP,
                query=q,
                key=k,
                value=v,
                output=output_attn,
                layer_name=self.layer_name,
                output_scale=None,
                output_block_scale=None,
                kv_cache_dummy_dep=kv_cache_dummy_dep,
            )
            flat_shape = [q.shape[0], self.num_heads * self.head_size]
            attn_out_view = RESHAPE_OP(attn_result[1], flat_shape)
            quantized = self.quant_matcher(attn_out_view, scale)
            return quantized[0]

        def replacement(
            q: torch.Tensor,
            k: torch.Tensor,
            v: torch.Tensor,
            output_attn: torch.Tensor,
            scale: torch.Tensor,
            kv_cache_dummy_dep: torch.Tensor,
        ) -> torch.Tensor:
            # Fresh attention output buffer, already in the quantized dtype.
            output_attn = torch.ops.aten.full.default(
                [q.shape[0], self.num_heads, self.head_size],
                0.0,
                dtype=self.quant_dtype,
                device=q.device,
            )
            attn_result = auto_functionalized(
                ATTN_OP,
                query=q,
                key=k,
                value=v,
                output=output_attn,
                layer_name=self.layer_name,
                output_scale=scale,
                output_block_scale=None,
                kv_cache_dummy_dep=kv_cache_dummy_dep,
            )
            flat_shape = [-1, self.num_heads * self.head_size]
            return RESHAPE_OP(attn_result[1], flat_shape)

        attn_shape = (5, self.num_heads, self.head_size)
        inputs = [
            self.empty(*attn_shape),  # q
            self.empty(*attn_shape),  # k
            self.empty(*attn_shape),  # v
            self.empty(*attn_shape),  # attn_output
            empty_fp32(1, 1),  # scale
            self.empty(0),  # kv_cache_dummy_dep
        ]

        # Trace with fwd_only, then normalize views and drop identity permutes.
        trace_fn = AttentionQuantPattern.wrap_trace_fn(
            pm.fwd_only,
            AttentionQuantPattern.fx_view_to_reshape,
            AttentionQuantPattern.remove_noop_permutes,
        )
        pm.register_replacement(
            pattern,
            replacement,
            inputs,
            trace_fn,
            pm_pass,
        )
class AttentionNvfp4QuantPattern(AttentionQuantPattern):
    """
    Fusion for Attention+Nvfp4Quant.

    Only triggers when the attention implementation returns True in
    `fused_output_quant_supported()`. When the pattern is found, the
    Nvfp4Quant op is removed from the graph and its scale is passed
    to the Attention op as the `output_scale` argument.
    """

    def __init__(self, layer: Attention, dtype: torch.dtype) -> None:
        """
        Args:
            layer: Attention layer this pattern is specialized for.
            dtype: Dtype passed on to the base class.
        """
        super().__init__(layer, kNvfp4Dynamic, dtype)

    def _register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the attention + nvfp4 quant fusion on `pm_pass`."""

        def pattern(
            q: torch.Tensor,
            k: torch.Tensor,
            v: torch.Tensor,
            output_attn: torch.Tensor,
            output_quant: torch.Tensor,
            output_scale: torch.Tensor,
            input_scale: torch.Tensor,
            kv_cache_dummy_dep: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            attn_result = auto_functionalized(
                ATTN_OP,
                query=q,
                key=k,
                value=v,
                output=output_attn,
                layer_name=self.layer_name,
                output_scale=None,
                output_block_scale=None,
                kv_cache_dummy_dep=kv_cache_dummy_dep,
            )
            flat_shape = [q.shape[0], self.num_heads * self.head_size]
            attn_out_view = RESHAPE_OP(attn_result[1], flat_shape)
            quant_result = auto_functionalized(
                self.QUANT_OP,
                output=output_quant,
                input=attn_out_view,
                output_scale=output_scale,
                input_scale=input_scale,
                is_sf_swizzled_layout=True,
            )
            # The scale output is reinterpreted as FP8_DTYPE before returning.
            output_scale_view = torch.ops.aten.view.dtype(quant_result[2], FP8_DTYPE)
            return quant_result[1], output_scale_view

        def replacement(
            q: torch.Tensor,
            k: torch.Tensor,
            v: torch.Tensor,
            output_attn: torch.Tensor,
            output_quant: torch.Tensor,
            output_scale: torch.Tensor,
            input_scale: torch.Tensor,
            kv_cache_dummy_dep: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # Fresh attention output buffer in the quantized dtype; the last
            # dimension is head_size // 2.
            output_attn = torch.ops.aten.full.default(
                [q.shape[0], self.num_heads, self.head_size // 2],
                0.0,
                dtype=self.quant_dtype,
                device=q.device,
            )
            # Block-scale buffer handed to attention, viewed as FP8_DTYPE.
            output_scale_view = torch.ops.aten.view.dtype(output_scale, FP8_DTYPE)
            attn_result = auto_functionalized(
                ATTN_OP,
                query=q,
                key=k,
                value=v,
                output=output_attn,
                layer_name=self.layer_name,
                output_scale=input_scale,
                output_block_scale=output_scale_view,
                kv_cache_dummy_dep=kv_cache_dummy_dep,
            )
            flat_shape = [-1, self.num_heads * self.head_size // 2]
            output = RESHAPE_OP(attn_result[1], flat_shape)
            return output, attn_result[2]

        attn_shape = (5, self.num_heads, self.head_size)
        inputs = [
            empty_bf16(*attn_shape),  # q
            empty_bf16(*attn_shape),  # k
            empty_bf16(*attn_shape),  # v
            empty_bf16(*attn_shape),  # output_attn
            self.empty_quant(5, self.num_heads * self.head_size // 2),  # output_quant
            empty_i32(
                128, round_up(self.num_heads * self.head_size // 16, 4)
            ),  # output_scale
            empty_fp32(1, 1),  # input_scale
            self.empty(0),  # kv_cache_dummy_dep
        ]

        # Trace with fwd_only, then normalize views and drop identity permutes.
        trace_fn = AttentionQuantPattern.wrap_trace_fn(
            pm.fwd_only,
            AttentionQuantPattern.fx_view_to_reshape,
            AttentionQuantPattern.remove_noop_permutes,
        )
        pm.register_replacement(
            pattern,
            replacement,
            inputs,
            trace_fn,
            pm_pass,
        )
class AttnFusionPass(VllmPatternMatcherPass):
    """
    This pass fuses post-attention quantization onto attention if supported.

    It uses the pattern matcher and matches each layer manually, as strings
    cannot be wildcarded. This also lets us check support on attention layers
    upon registration instead of during pattern matching.

    A static fp8 quant pattern is registered for every attention layer. An
    nvfp4 quant pattern is registered as well when running on CUDA and
    `torch.ops._C.scaled_fp4_quant` is available. Patterns could easily be
    added for other quant schemes and dtypes. The bigger hurdle for wider
    support are attention kernels, which need to support fusing output quant.
    """

    @enable_fake_mode
    def __init__(self, config: VllmConfig) -> None:
        """
        Build one pattern per attention layer found in `config` and register
        each one that the layer's attention implementation supports.
        """
        super().__init__(config)

        self.patterns = PatternMatcherPass(pass_name="attn_fusion_pass")

        attn_layers = get_layers_from_vllm_config(config, Attention)
        # The nvfp4 availability check does not depend on the layer, so it is
        # evaluated once instead of once per layer.
        nvfp4_available = current_platform.is_cuda() and hasattr(
            torch.ops._C, "scaled_fp4_quant"
        )
        # Only the layer objects are needed; the layer name is read from the
        # layer itself inside the pattern classes.
        for layer in attn_layers.values():
            pattern_fp8 = AttentionFp8StaticQuantPattern(
                layer, config.model_config.dtype
            )
            pattern_fp8.register_if_supported(self.patterns)

            if nvfp4_available:
                pattern_nvfp4 = AttentionNvfp4QuantPattern(
                    layer, config.model_config.dtype
                )
                pattern_nvfp4.register_if_supported(self.patterns)

        if not attn_layers:
            logger.warning(
                "Attention + quant fusion is enabled, but no attention layers "
                "were found in CompilationConfig.static_forward_context "
                "so no fusion patterns were registered."
            )

        self.dump_patterns(config, self.patterns)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: torch.fx.graph.Graph) -> None:
        """Apply the registered patterns to `graph` and record the match count."""
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Fused quant onto %s attention nodes", self.matched_count)

    def uuid(self) -> str:
        """Return a hash over the source of this pass and its pattern classes."""
        return VllmInductorPass.hash_source(
            self,
            AttentionQuantPattern,
            AttentionFp8StaticQuantPattern,
            AttentionNvfp4QuantPattern,
        )

View File

@@ -0,0 +1,423 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import torch._inductor.pattern_matcher as pm
import torch.fx as fx
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch.distributed._symmetric_memory import enable_symm_mem_for_group
from vllm.config import VllmConfig
from vllm.config.utils import Range
from vllm.distributed import get_tp_group
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_world_size,
)
from vllm.logger import init_logger
from vllm.platforms import current_platform
from ..inductor_pass import enable_fake_mode
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
# fp8 dtype reported by the current platform; used for the example inputs of
# the scaled-mm patterns below.
FP8_DTYPE = current_platform.fp8_dtype()
# Module-level logger.
logger = init_logger(__name__)
class BasePattern:
    """Shared state for the async-TP fusion patterns in this module."""

    def __init__(self, dtype: torch.dtype, device: str | None) -> None:
        """
        Args:
            dtype: Dtype stored for use by subclasses.
            device: Device stored for use by subclasses.
        """
        # Stored for the subclasses' example inputs.
        self.dtype, self.device = dtype, device
        # Tensor-parallel group handle and its world size.
        self.tp = get_tp_group()
        self.tp_size = get_tensor_model_parallel_world_size()
class GEMMReduceScatterPattern(BasePattern):
    """
    Replace `aten.mm` followed by `vllm.reduce_scatter` with
    `symm_mem.fused_matmul_reduce_scatter`.
    """

    def get_inputs(self) -> list[torch.Tensor]:
        """Return example tensors ordered as (mul, mm_weight)."""
        lhs = torch.empty([16, 4], device=self.device, dtype=self.dtype)
        rhs = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        return [lhs, rhs]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the search pattern and its fused replacement on `pm_pass`."""

        def pattern(mul: torch.Tensor, mm_weight: torch.Tensor) -> torch.Tensor:
            product = torch.ops.aten.mm.default(mul, mm_weight)
            return torch.ops.vllm.reduce_scatter.default(
                product,
                dim=0,
                world_size=self.tp_size,
                group_name=self.tp.unique_name,
            )

        def replacement(mul: torch.Tensor, mm_weight: torch.Tensor) -> torch.Tensor:
            # NOTE(review): the reduce op is passed as "avg" -- confirm this
            # matches the reduction performed by vllm.reduce_scatter.
            return torch.ops.symm_mem.fused_matmul_reduce_scatter(
                mul,
                mm_weight,
                "avg",
                scatter_dim=0,
                group_name=self.tp.device_group.group_name,
            )

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )
class AllGatherGEMMPattern(BasePattern):
    """
    Replace `vllm.all_gather` followed by `aten.mm` with
    `symm_mem.fused_all_gather_matmul`.
    """

    def get_inputs(self) -> list[torch.Tensor]:
        """Return example tensors ordered as (x, weight)."""
        shape = [4, 4]
        x = torch.empty(shape, device=self.device, dtype=self.dtype)
        weight = torch.empty(shape, device=self.device, dtype=self.dtype)
        return [x, weight]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the search pattern and its fused replacement on `pm_pass`."""

        def pattern(
            x: torch.Tensor,
            weight: torch.Tensor,
        ) -> torch.Tensor:
            gathered = torch.ops.vllm.all_gather.default(
                x,
                dim=0,
                world_size=self.tp_size,
                group_name=self.tp.unique_name,
            )
            product = torch.ops.aten.mm.default(gathered, weight)
            return product

        def replacement(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
            fused = torch.ops.symm_mem.fused_all_gather_matmul(
                x,
                [weight],
                gather_dim=0,
                group_name=self.tp.device_group.group_name,
            )
            # The fused op yields (ag_output, mm_outputs); only the matmul
            # outputs are returned.
            _ag_output, mm_outputs = fused
            return mm_outputs

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )
class ScaledMMReduceScatterPattern(BasePattern):
    """
    Replace `aten._scaled_mm` followed by `vllm.reduce_scatter` with
    `vllm.patched_fused_scaled_matmul_reduce_scatter`.
    """

    def get_inputs(self) -> list[torch.Tensor]:
        """Return example tensors ordered as (input, mm_weight, scale_a, scale_b)."""
        dev = self.device
        input = torch.empty([16, 16], device=dev, dtype=FP8_DTYPE)
        # The weight example is a transposed view of a contiguous tensor.
        weight_base = torch.empty([16, 16], device=dev, dtype=FP8_DTYPE)
        mm_weight = weight_base.contiguous().transpose(0, 1)
        scale_a = torch.empty([16, 1], device=dev, dtype=torch.float32)
        scale_b = torch.empty([1, 16], device=dev, dtype=torch.float32)
        return [input, mm_weight, scale_a, scale_b]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the search pattern and its fused replacement on `pm_pass`."""

        def pattern(
            input: torch.Tensor,
            mat2: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
        ) -> torch.Tensor:
            product = torch.ops.aten._scaled_mm.default(
                input,
                mat2=mat2,
                scale_a=scale_a,
                scale_b=scale_b,
                bias=None,
                scale_result=None,
                out_dtype=self.dtype,
            )
            return torch.ops.vllm.reduce_scatter.default(
                product,
                dim=0,
                world_size=self.tp_size,
                group_name=self.tp.unique_name,
            )

        def replacement(
            input: torch.Tensor,
            mat2: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
        ) -> torch.Tensor:
            # Shape of input @ mat2: leading dims of input, columns of mat2.
            output_shape = [*input.shape[:-1], mat2.shape[1]]
            scatter_dim = 0
            # NOTE(review): the reduce op is passed as "avg" -- confirm this
            # matches the reduction performed by vllm.reduce_scatter.
            return torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter(
                input,
                mat2,
                scale_a,
                scale_b,
                "avg",
                scatter_dim,  # orig_scatter_dim
                scatter_dim,  # scatter_dim_after_maybe_reshape
                self.tp.device_group.group_name,
                output_shape,
                None,  # bias
                None,  # result_scale
                self.dtype,  # out_dtype
                False,  # use_fast_accum
            )

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )
class AllGatherScaledMMPattern(BasePattern):
    """
    Replace `vllm.all_gather` followed by `aten._scaled_mm` with
    `symm_mem.fused_all_gather_scaled_matmul`.
    """

    def get_inputs(self) -> list[torch.Tensor]:
        """Return example tensors ordered as (x, weight, scale_a, scale_b)."""
        dev = self.device
        x = torch.empty([8, 16], device=dev, dtype=FP8_DTYPE)
        # The weight example is a transposed view of a contiguous tensor.
        weight_base = torch.empty([16, 16], device=dev, dtype=FP8_DTYPE)
        weight = weight_base.contiguous().transpose(0, 1)

        # scale_a has one row per gathered row (local rows * tp_size).
        gathered_rows = x.shape[0] * self.tp_size
        scale_a = torch.empty([gathered_rows, 1], device=dev, dtype=torch.float32)
        scale_b = torch.empty([1, 16], device=dev, dtype=torch.float32)
        return [x, weight, scale_a, scale_b]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the search pattern and its fused replacement on `pm_pass`."""

        def pattern(
            x: torch.Tensor,
            weight: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
        ) -> torch.Tensor:
            gathered = torch.ops.vllm.all_gather.default(
                x, dim=0, world_size=self.tp_size, group_name=self.tp.unique_name
            )
            product = torch.ops.aten._scaled_mm.default(
                gathered,
                mat2=weight,
                scale_a=scale_a,
                scale_b=scale_b,
                bias=None,
                scale_result=None,
                out_dtype=self.dtype,
            )
            return product

        def replacement(
            x: torch.Tensor,
            weight: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
        ) -> torch.Tensor:
            fused = torch.ops.symm_mem.fused_all_gather_scaled_matmul(
                x,
                [weight],
                scale_a,
                [scale_b],
                gather_dim=0,
                biases=[None],
                result_scales=[None],
                out_dtypes=[self.dtype],
                use_fast_accum=[False],
                group_name=self.tp.device_group.group_name,
            )
            # The fused op yields (ag_output, mm_outputs); only the matmul
            # outputs are returned.
            _ag_output, mm_outputs = fused
            return mm_outputs

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )
class CutlassScaledMMReduceScatterPattern(BasePattern):
    """
    Replace `_C.cutlass_scaled_mm` followed by `vllm.reduce_scatter` with
    `vllm.patched_fused_scaled_matmul_reduce_scatter`.
    """

    def get_inputs(self) -> list[torch.Tensor]:
        """
        Return example tensors ordered as
        (input, mm_weight, scale_a, scale_b, cutlass_mm_output).
        """
        dev = self.device
        input = torch.empty([16, 16], device=dev, dtype=FP8_DTYPE)
        # The weight example is a transposed view of a contiguous tensor.
        weight_base = torch.empty([16, 16], device=dev, dtype=FP8_DTYPE)
        mm_weight = weight_base.contiguous().transpose(0, 1)
        scale_a = torch.empty([16, 1], device=dev, dtype=torch.float32)
        scale_b = torch.empty([1, 16], device=dev, dtype=torch.float32)

        cutlass_mm_output = torch.empty([16, 16], device=dev, dtype=self.dtype)
        return [input, mm_weight, scale_a, scale_b, cutlass_mm_output]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the search pattern and its fused replacement on `pm_pass`."""

        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
            cutlass_mm_output: torch.Tensor,
        ) -> torch.Tensor:
            mm_result = torch.ops.higher_order.auto_functionalized(
                torch.ops._C.cutlass_scaled_mm.default,
                out=cutlass_mm_output,
                a=input,
                b=weight,
                a_scales=scale_a,
                b_scales=scale_b,
                bias=None,
            )
            # Index 1 of the functionalized result is the written `out` tensor.
            return torch.ops.vllm.reduce_scatter.default(
                mm_result[1],
                dim=0,
                world_size=self.tp_size,
                group_name=self.tp.unique_name,
            )

        def replacement(
            input: torch.Tensor,
            mat2: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
            cutlass_mm_output: torch.Tensor,
        ) -> torch.Tensor:
            # Shape of input @ mat2: leading dims of input, columns of mat2.
            output_shape = [*input.shape[:-1], mat2.shape[1]]
            scatter_dim = 0
            # NOTE(review): the reduce op is passed as "avg" -- confirm this
            # matches the reduction performed by vllm.reduce_scatter.
            return torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter(
                input,
                mat2,
                scale_a,
                scale_b,
                "avg",
                scatter_dim,  # orig_scatter_dim
                scatter_dim,  # scatter_dim_after_maybe_reshape
                self.tp.device_group.group_name,
                output_shape,
                None,  # bias
                None,  # result_scale
                self.dtype,  # out_dtype
                False,  # use_fast_accum
            )

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )
class AllGatherCutlassScaledMMPattern(BasePattern):
    """
    Replace `vllm.all_gather` followed by `_C.cutlass_scaled_mm` with
    `symm_mem.fused_all_gather_scaled_matmul`.
    """

    def get_inputs(self) -> list[torch.Tensor]:
        """Return example tensors ordered as (x, weight, scale_a, scale_b, output)."""
        dev = self.device
        x = torch.empty([8, 16], device=dev, dtype=FP8_DTYPE)
        # The weight example is a transposed view of a contiguous tensor.
        weight_base = torch.empty([16, 16], device=dev, dtype=FP8_DTYPE)
        weight = weight_base.contiguous().transpose(0, 1)

        # scale_a has one row per gathered row (local rows * tp_size).
        gathered_rows = x.shape[0] * self.tp_size
        scale_a = torch.empty([gathered_rows, 1], device=dev, dtype=torch.float32)
        scale_b = torch.empty([1, 16], device=dev, dtype=torch.float32)

        out_cols = weight.shape[1]
        output = torch.empty([gathered_rows, out_cols], device=dev, dtype=self.dtype)
        return [x, weight, scale_a, scale_b, output]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the search pattern and its fused replacement on `pm_pass`."""

        def pattern(
            x: torch.Tensor,
            weight: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
            output: torch.Tensor,
        ) -> torch.Tensor:
            gathered = torch.ops.vllm.all_gather.default(
                x, dim=0, world_size=self.tp_size, group_name=self.tp.unique_name
            )
            mm_result = torch.ops.higher_order.auto_functionalized(
                torch.ops._C.cutlass_scaled_mm.default,
                out=output,
                a=gathered,
                b=weight,
                a_scales=scale_a,
                b_scales=scale_b,
                bias=None,
            )
            # Index 1 of the functionalized result is the written `out` tensor.
            return mm_result[1]

        def replacement(
            x: torch.Tensor,
            weight: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
            output: torch.Tensor,
        ) -> torch.Tensor:
            fused = torch.ops.symm_mem.fused_all_gather_scaled_matmul(
                x,
                [weight],
                scale_a,
                [scale_b],
                gather_dim=0,
                biases=[None],
                result_scales=[None],
                out_dtypes=[self.dtype],
                use_fast_accum=[False],
                group_name=self.tp.device_group.group_name,
            )
            # The fused op yields (ag_output, mm_outputs); only the matmul
            # outputs are returned.
            _ag_output, mm_outputs = fused
            return mm_outputs

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )
class AsyncTPPass(VllmPatternMatcherPass):
    """
    Pattern-matcher pass that swaps GEMM + collective sequences for the
    fused symmetric-memory ops registered by the pattern classes above.
    """

    @enable_fake_mode
    def __init__(self, config: VllmConfig) -> None:
        """Enable symmetric memory for the TP group and register all patterns."""
        super().__init__(config)

        # Enable symmetric memory for the TP process group
        tp_group_name = get_tp_group().device_group.group_name
        enable_symm_mem_for_group(tp_group_name)

        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="async_tp_pass"
        )

        pattern_classes: list[type[BasePattern]] = [
            GEMMReduceScatterPattern,
            AllGatherGEMMPattern,
        ]
        # These fusions are enabled only for bfloat16 models because
        # `scaled_mm` or `cutlass_scaled_mm` with per-token (row-wise) scaling
        # only supports bfloat16 as the output dtype.
        if self.model_dtype == torch.bfloat16:
            pattern_classes += [
                ScaledMMReduceScatterPattern,
                AllGatherScaledMMPattern,
                CutlassScaledMMReduceScatterPattern,
                AllGatherCutlassScaledMMPattern,
            ]

        for pattern_cls in pattern_classes:
            pattern_cls(self.model_dtype, self.device).register(self.patterns)

        self.dump_patterns(config, self.patterns)

    def is_applicable_for_range(self, compile_range: Range) -> bool:
        """
        Return whether this pass applies to `compile_range`.

        Always applicable when there are no splitting ops or when inductor
        graph partition is used; otherwise the range must be a single size
        divisible by the tensor-parallel world size.
        """
        # This pass is applied on top of the sequence parallelism pass.
        # It inherits the same applicability condition as `SequenceParallelismPass`.
        # See `SequenceParallelismPass.is_applicable` for more details.
        comp_cfg = self.compilation_config
        if not comp_cfg.splitting_ops or comp_cfg.use_inductor_graph_partition:
            return True

        tp_size = get_tensor_model_parallel_world_size()
        if not compile_range.is_single_size():
            return False
        return bool(compile_range.end % tp_size == 0)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph) -> None:
        """Apply the registered patterns to `graph` and record the match count."""
        num_matched = self.patterns.apply(graph)
        self.matched_count = num_matched
        logger.debug("Replaced %s patterns", self.matched_count)

View File

@@ -0,0 +1,472 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from typing import Any
import torch
from torch._higher_order_ops import auto_functionalized
from torch._ops import OpOverload
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import get_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
QuantKey,
_normalize_quant_group_shape,
kFp8Dynamic64Sym,
kFp8Dynamic128Sym,
kFp8DynamicTensorSym,
kFp8DynamicTokenSym,
kFp8StaticTensorSym,
kNvfp4Dynamic,
)
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.platforms import current_platform
# Handles to the default overloads of the custom ops emitted by the matcher
# classes in this module when they match the custom-op form of a layer.
RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
ROTARY_OP = torch.ops._C.rotary_embedding.default
FLASHINFER_ROTARY_OP = torch.ops.vllm.flashinfer_rotary_embedding.default
# Quantization scheme -> custom quant op. MatcherQuantFP8 looks its op up here
# when it is not matching ROCm aiter ops.
QUANT_OPS: dict[QuantKey, OpOverload] = {
    kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default,  # noqa: E501
    kFp8DynamicTensorSym: torch.ops._C.dynamic_scaled_fp8_quant.default,  # noqa: E501
    kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default,  # noqa: E501
}
# The remaining entries are only registered on CUDA platforms; the FP4 op is
# additionally guarded by its presence in torch.ops._C.
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
    QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.default  # noqa: E501
if current_platform.is_cuda():
    # Both group sizes (128 and 64) map to the same per-token-group op.
    QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default  # noqa: E501
    QUANT_OPS[kFp8Dynamic64Sym] = torch.ops._C.per_token_group_fp8_quant.default  # noqa: E501
# Custom op emitted by MatcherSiluAndMul.forward_custom.
SILU_MUL_OP = torch.ops._C.silu_and_mul.default
class MatcherCustomOp(ABC):
    """Base class for helpers that emit a layer in custom-op or native form.

    Calling an instance dispatches to ``forward_custom`` when ``enabled`` is
    true and to ``forward_native`` otherwise. The model dtype and device are
    read from the current vLLM config and used by the ``empty*`` helpers to
    build sample tensors.
    """

    def __init__(self, enabled: bool) -> None:
        vllm_config = get_current_vllm_config()
        model_config = vllm_config.model_config
        device_config = vllm_config.device_config

        # Either sub-config may be absent, in which case the value is None.
        self.model_dtype = model_config.dtype if model_config else None
        self.device = device_config.device if device_config else None

        self.enabled = enabled
        if enabled:
            self.forward = self.forward_custom
        else:
            self.forward = self.forward_native

    @abstractmethod
    def forward_custom(self, *args: Any, **kwargs: Any) -> Any:
        """Emit the custom-op form of the operation."""

    @abstractmethod
    def forward_native(self, *args: Any, **kwargs: Any) -> Any:
        """Emit the native (torch) form of the operation."""

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # ``self.forward`` was bound in __init__ according to ``enabled``.
        return self.forward(*args, **kwargs)

    def empty(self, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Return an uninitialized tensor in the model dtype on ``self.device``."""
        return torch.empty(*args, **kwargs, dtype=self.model_dtype, device=self.device)

    def empty_int64(self, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Return an uninitialized int64 tensor on ``self.device``."""
        return torch.empty(*args, **kwargs, dtype=torch.int64, device=self.device)

    def empty_f32(self, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Return an uninitialized float32 tensor on ``self.device``."""
        return torch.empty(*args, **kwargs, dtype=torch.float32, device=self.device)

    def inputs(self) -> list[torch.Tensor]:
        """Utility for inputs to the pattern"""
        raise NotImplementedError
class MatcherRotaryEmbedding(MatcherCustomOp):
    """Emit rotary embedding either as a custom op or via the native path.

    The custom op is chosen at construction time: the FlashInfer op when
    ``use_flashinfer`` is set, otherwise the ROCm aiter triton op when
    ``match_rocm_aiter`` is set, otherwise the default ``_C`` op.
    """

    def __init__(
        self,
        is_neox: bool,
        head_size: int,
        num_heads: int,
        num_kv_heads: int,
        use_flashinfer: bool = False,
        match_rocm_aiter: bool | None = None,
        enabled: bool | None = None,
    ) -> None:
        # Unspecified flags fall back to the globally configured state.
        if enabled is None:
            enabled = RotaryEmbedding.enabled()
        if match_rocm_aiter is None:
            match_rocm_aiter = rocm_aiter_ops.is_triton_rotary_embed_enabled()

        super().__init__(enabled)

        self.is_neox = is_neox
        self.head_size = head_size
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.q_size = self.num_heads * self.head_size
        self.kv_size = self.num_kv_heads * self.head_size
        # The rotary dimension is taken to be the full head size.
        self.rotary_dim = head_size

        # FlashInfer takes precedence over ROCm aiter.
        if use_flashinfer:
            selected_op = FLASHINFER_ROTARY_OP
        elif match_rocm_aiter:
            selected_op = rocm_aiter_ops.get_triton_rotary_embedding_op()
        else:
            selected_op = ROTARY_OP
        self.rotary_op = selected_op

    def inputs(self) -> list[torch.Tensor]:
        """Return sample ``[positions, query, key, cos_sin_cache]`` tensors."""
        num_tokens = 5
        return [
            self.empty_int64(num_tokens),
            self.empty(num_tokens, self.q_size),
            self.empty(num_tokens, self.kv_size),
            self.empty(4096, self.rotary_dim),
        ]

    def forward_custom(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor | None,
        cos_sin_cache: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Apply the selected rotary custom op through ``auto_functionalized``."""
        outputs = auto_functionalized(
            self.rotary_op,
            positions=positions,
            query=query,
            key=key,
            head_size=self.head_size,
            cos_sin_cache=cos_sin_cache,
            is_neox=self.is_neox,
        )
        rotated_query = outputs[1]
        # A key output is only present when the result has a third entry.
        if len(outputs) > 2:
            rotated_key = outputs[2]
        else:
            rotated_key = None
        return rotated_query, rotated_key

    def forward_native(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor | None,
        cos_sin_cache: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Apply rotary embedding via ``RotaryEmbedding.forward_static``."""
        rotated: tuple[torch.Tensor, torch.Tensor | None] = (
            RotaryEmbedding.forward_static(
                positions,
                query,
                key,
                self.head_size,
                self.rotary_dim,
                cos_sin_cache,
                self.is_neox,
            )
        )
        return rotated
class MatcherRMSNorm(MatcherCustomOp):
    """Emit RMSNorm as a custom op (default or ROCm aiter) or in native form."""

    def __init__(
        self,
        epsilon: float,
        enabled: bool | None = None,
        match_rocm_aiter: bool = False,
    ) -> None:
        # Default to the globally configured RMSNorm custom-op state.
        if enabled is None:
            enabled = RMSNorm.enabled()
        super().__init__(enabled)

        self.epsilon = epsilon
        self.match_rocm_aiter = match_rocm_aiter
        self._rmsnorm_op = (
            rocm_aiter_ops.get_rmsnorm_op() if match_rocm_aiter else RMS_OP
        )

    def inputs(self) -> list[torch.Tensor]:
        """Return sample ``[input, weight]`` tensors.

        The sample input uses the model dtype when the custom op is enabled
        and float32 otherwise; the weight always uses the model dtype.
        """
        if self.enabled:
            sample_input = self.empty(5, 16)
        else:
            sample_input = self.empty_f32(5, 16)
        sample_weight = self.empty(16)
        return [sample_input, sample_weight]

    def forward_rocm_aiter(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
    ) -> torch.Tensor:
        """Call the ROCm aiter RMSNorm op directly."""
        return self._rmsnorm_op(
            x=input,
            weight=weight,
            variance_epsilon=self.epsilon,
        )

    def forward_custom(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
    ) -> torch.Tensor:
        """Emit the custom-op form, using the aiter op when configured."""
        if self.match_rocm_aiter:
            return self.forward_rocm_aiter(input, weight)

        out_buffer = torch.empty_like(input)
        _, normed = auto_functionalized(
            self._rmsnorm_op,
            result=out_buffer,
            input=input,
            weight=weight,
            epsilon=self.epsilon,
        )
        return normed

    def forward_native(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
    ) -> torch.Tensor:
        """Emit the native form via ``RMSNorm.forward_static``."""
        hidden_size = input.size(-1)
        return RMSNorm.forward_static(
            input, self.epsilon, hidden_size, self.model_dtype, weight
        )
class MatcherFusedAddRMSNorm(MatcherCustomOp):
    """Emit residual-add + RMSNorm as a custom op (default or aiter) or natively."""

    def __init__(
        self,
        epsilon: float,
        enabled: bool | None = None,
        match_rocm_aiter: bool = False,
    ) -> None:
        # Default to the globally configured RMSNorm custom-op state.
        if enabled is None:
            enabled = RMSNorm.enabled()
        super().__init__(enabled)

        self.epsilon = epsilon
        self.match_rocm_aiter = match_rocm_aiter
        self._rmsnorm_op = (
            rocm_aiter_ops.get_rmsnorm_fused_add_op()
            if match_rocm_aiter
            else RMS_ADD_OP
        )

    def inputs(self) -> list[torch.Tensor]:
        """Return sample ``[input, weight, residual]`` tensors.

        The sample input uses the model dtype when the custom op is enabled
        and float32 otherwise; weight and residual use the model dtype.
        """
        if self.enabled:
            sample_input = self.empty(5, 16)
        else:
            sample_input = self.empty_f32(5, 16)
        sample_weight = self.empty(16)
        sample_residual = self.empty(5, 16)
        return [sample_input, sample_weight, sample_residual]

    def forward_rocm_aiter(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
        residual: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Call the ROCm aiter fused-add RMSNorm op directly."""
        return self._rmsnorm_op(  # type: ignore[no-any-return]
            x=input, residual=residual, weight=weight, variance_epsilon=self.epsilon
        )

    def forward_custom(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
        residual: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Emit the custom-op form, using the aiter op when configured."""
        if self.match_rocm_aiter:
            return self.forward_rocm_aiter(input, weight, residual)

        _, normed, updated_residual = auto_functionalized(
            self._rmsnorm_op,
            input=input,
            residual=residual,
            weight=weight,
            epsilon=self.epsilon,
        )
        return normed, updated_residual

    def forward_native(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
        residual: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Emit the native form via ``RMSNorm.forward_static`` with a residual."""
        hidden_size = input.size(-1)
        normed_and_residual: tuple[torch.Tensor, torch.Tensor] = (
            RMSNorm.forward_static(
                input, self.epsilon, hidden_size, self.model_dtype, weight, residual
            )
        )
        return normed_and_residual
class MatcherQuantFP8(MatcherCustomOp):
    """Emit FP8 quantization as a custom op (default or ROCm aiter) or natively.

    The quant op is selected at construction time from ``quant_key``: from the
    ROCm aiter ops when ``match_rocm_aiter`` is set, otherwise from the
    module-level ``QUANT_OPS`` table. The native form delegates to a
    ``QuantFP8`` layer built from the same settings.
    """

    def __init__(
        self,
        quant_key: QuantKey,
        enabled: bool | None = None,
        has_col_major_scales: bool = False,
        is_e8m0: bool = False,
        match_rocm_aiter: bool = False,
        is_tma_aligned: bool = False,
    ) -> None:
        # Default to the globally configured QuantFP8 custom-op state.
        if enabled is None:
            enabled = QuantFP8.enabled()
        super().__init__(enabled)
        self.quant_key = quant_key
        self.has_col_major_scales = has_col_major_scales
        self.is_e8m0 = is_e8m0
        self.match_rocm_aiter = match_rocm_aiter
        self.is_tma_aligned = is_tma_aligned
        if match_rocm_aiter:
            assert not quant_key.scale.group_shape.is_per_tensor(), (
                "ROCm aiter fusion pass does not support per tensor quantization"
            )
            if quant_key.scale.group_shape.is_per_token():
                self.QUANT_OP = rocm_aiter_ops.get_per_token_quant_op()
            else:
                assert quant_key.scale.group_shape.col == 128, (
                    "ROCm aiter fusion pass currently supports "
                    "quantization operation with group_size 128"
                )
                # Group quant: aiter op on fnuz FP8 platforms, triton op otherwise.
                if current_platform.is_fp8_fnuz():
                    self.QUANT_OP = rocm_aiter_ops.get_group_quant_op()
                else:
                    self.QUANT_OP = (
                        torch.ops.vllm.triton_per_token_group_quant_fp8.default
                    )
        else:
            assert quant_key in QUANT_OPS, (
                f"unsupported quantization scheme {quant_key}"
            )
            self.QUANT_OP = QUANT_OPS[quant_key]
            # The original message was cut off mid-sentence; it now names the
            # class so the failure is self-explanatory.
            assert quant_key.dtype == current_platform.fp8_dtype(), (
                "Only FP8 quantization is supported by MatcherQuantFP8"
            )
            assert quant_key.scale2 is None
        # Native implementation used by forward_native.
        self.quant_fp8 = QuantFP8(
            quant_key.scale.static,
            quant_key.scale.group_shape,
            column_major_scales=has_col_major_scales,
            use_ue8m0=is_e8m0,
            tma_aligned_scales=self.is_tma_aligned,
            compile_native=False,
        )

    def forward_rocm_aiter(
        self,
        input: torch.Tensor,
        scale: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Call the ROCm aiter quant op selected in ``__init__``.

        Per-token quantization forwards ``scale`` and the quant dtype; any
        other group shape passes the group's column size instead.
        """
        quant_key_group_shape = self.quant_key.scale.group_shape
        if quant_key_group_shape == GroupShape.PER_TOKEN:
            return self.QUANT_OP(  # type: ignore[no-any-return]
                x=input,
                quant_dtype=self.quant_key.dtype,
                scale=scale,
            )
        else:
            return self.QUANT_OP(input, quant_key_group_shape.col)  # type: ignore[no-any-return]

    def forward_custom(
        self,
        input: torch.Tensor,
        scale: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Emit the custom-op form of the quantization.

        Three call shapes are produced, depending on the quant key:
        per-group, static (``scale`` required) and dynamic (``scale`` must be
        ``None`` and is allocated here).

        Returns:
            ``(quantized, scale)``.
        """
        if self.match_rocm_aiter:
            return self.forward_rocm_aiter(input, scale)
        result = torch.empty(
            input.shape, device=input.device, dtype=self.quant_key.dtype
        )
        if self.quant_key.scale.group_shape.is_per_group():
            # for tma_aligned, the scale must be passed to forward_custom
            # tma_aligned fusion then matches by custom op arguments
            if not self.is_tma_aligned:
                assert scale is None
                scale = self.make_scale(input, transposed=self.has_col_major_scales)
            finfo = torch.finfo(self.quant_key.dtype)
            fp8_min = finfo.min
            fp8_max = finfo.max
            _, result, scale = auto_functionalized(
                self.QUANT_OP,
                input=input,
                output_q=result,
                output_s=scale,
                group_size=self.quant_key.scale.group_shape[1],
                eps=1e-10,
                fp8_min=fp8_min,
                fp8_max=fp8_max,
                scale_ue8m0=self.is_e8m0,
                dummy_is_scale_transposed=self.has_col_major_scales,
                dummy_is_tma_aligned=self.is_tma_aligned,
            )
            return result, scale
        if self.quant_key.scale.static:
            # Static quantization: the caller supplies the scale.
            assert scale is not None
            _, result = auto_functionalized(
                self.QUANT_OP, result=result, input=input, scale=scale
            )
            return result, scale
        else:
            # Dynamic quantization: the op writes the scale into a fresh buffer.
            assert scale is None
            scale = self.make_scale(input)
            _, result, scale = auto_functionalized(
                self.QUANT_OP, result=result, input=input, scale=scale, scale_ub=None
            )
            return result, scale

    def forward_native(
        self,
        input: torch.Tensor,
        scale: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Emit the native form by calling the wrapped ``QuantFP8`` layer."""
        return self.quant_fp8(input, scale)  # type: ignore[no-any-return]

    def make_scale(self, input: torch.Tensor, transposed: bool = False) -> torch.Tensor:
        """Allocate an uninitialized float32 scale buffer for ``input``.

        The shape is ``input.shape[:2]`` divided element-wise by the
        normalized group shape. With ``transposed=True`` the buffer is
        allocated with the reversed shape and permuted back, so the returned
        tensor has the same logical shape over transposed storage.
        """
        normalized_group_shape = _normalize_quant_group_shape(
            input, self.quant_key.scale.group_shape
        )
        # Only the first two dimensions of ``input`` are consulted.
        scale_shape = (
            input.shape[0] // normalized_group_shape[0],
            input.shape[1] // normalized_group_shape[1],
        )
        if transposed:
            scale_shape = tuple(reversed(scale_shape))
            return torch.empty(
                scale_shape, device=input.device, dtype=torch.float32
            ).permute(-1, -2)
        return torch.empty(scale_shape, device=input.device, dtype=torch.float32)

    def inputs(self) -> list[torch.Tensor]:
        """Return sample inputs: ``[input]``, plus a 1x1 scale for static quant."""
        input = self.empty(5, 16)
        if self.quant_key.scale.static:
            return [input, self.empty_f32(1, 1)]
        return [input]
class MatcherSiluAndMul(MatcherCustomOp):
    """Emit SiLU-and-multiply either as the ``_C`` custom op or natively."""

    def __init__(self, enabled: bool | None = None) -> None:
        # Default to the globally configured SiluAndMul custom-op state.
        if enabled is None:
            enabled = SiluAndMul.enabled()
        super().__init__(enabled)

    def inputs(self) -> list[torch.Tensor]:
        """Return a single sample input tensor of shape (5, 4)."""
        return [self.empty(5, 4)]

    def forward_custom(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Emit the custom op; the output's last dim is half of the input's."""
        half_width = x.shape[-1] // 2
        out_buffer = torch.empty(
            x.shape[:-1] + (half_width,), dtype=x.dtype, device=x.device
        )
        outputs = auto_functionalized(SILU_MUL_OP, result=out_buffer, input=x)
        return outputs[1]

    def forward_native(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Emit the native form via ``SiluAndMul.forward_native``."""
        return SiluAndMul.forward_native(x)

View File

@@ -0,0 +1,244 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from typing import ParamSpec
import torch
import torch._inductor.pattern_matcher as pm
from torch import fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import PatternMatcherPass
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from ..inductor_pass import enable_fake_mode
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
from .matcher_utils import MatcherRMSNorm, MatcherRotaryEmbedding
from .rms_quant_fusion import empty_bf16, empty_fp32, empty_i64
# Module-level logger.
logger = init_logger(__name__)
# Fused custom op that QkNormRopePattern's replacement emits.
FUSED_QK_ROPE_OP = torch.ops._C.fused_qk_norm_rope.default
# Parameter specification used to type QkNormRopePattern.wrap_trace_fn.
P = ParamSpec("P")
class QkNormRopePattern:
    """
    Match the unfused sequence in attention blocks and replace with the fused op.
    Unfused (conceptually):
    q, k, v = split(qkv, [qsz, kvsz, kvsz], -1)
    qh = reshape(q, [-1, num_heads, head_dim])
    kh = reshape(k, [-1, num_kv_heads, head_dim])
    qn = rms_norm(qh, q_weight, eps)
    kn = rms_norm(kh, k_weight, eps)
    qf = reshape(qn, [-1, num_heads * head_dim])
    kf = reshape(kn, [-1, num_kv_heads * head_dim])
    qf, kf = rotary_embedding(positions, qf, kf, head_dim, cos_sin_cache, is_neox)
    return qf, kf, v
    Fused replacement:
    fused_qk_norm_rope(qkv, num_heads, num_kv_heads, num_kv_heads, head_dim,
    eps, q_weight, k_weight, cos_sin_cache, is_neox,
    positions.view(-1))
    return split(qkv, [qsz, kvsz, kvsz], -1)
    """
    def __init__(
        self,
        head_dim: int,
        num_heads: int,
        num_kv_heads: int,
        eps: float,
        is_neox: bool,
        rope_flashinfer: bool = False,
    ) -> None:
        """Store the attention geometry and build the RMSNorm / RoPE matchers."""
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        # Widths of the q and k/v slices of the packed qkv tensor.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.eps = eps
        # One RMSNorm matcher is shared by the Q and K paths.
        self.rmsnorm_matcher = MatcherRMSNorm(eps)
        self.is_neox = is_neox
        self.rope_flashinfer = rope_flashinfer
        self.rope_matcher = MatcherRotaryEmbedding(
            is_neox=is_neox,
            head_size=self.head_dim,
            num_heads=self.num_heads,
            num_kv_heads=self.num_kv_heads,
            use_flashinfer=self.rope_flashinfer,
        )
    def get_inputs(self) -> list[torch.Tensor]:
        """Return sample ``[qkv, positions, q_weight, k_weight, cos_sin_cache]``."""
        # Sample inputs to help pattern tracing
        T = 5
        qkv = empty_bf16(T, self.q_size + 2 * self.kv_size)
        positions = empty_i64(T)
        q_weight = empty_bf16(1, self.head_dim)
        k_weight = empty_bf16(1, self.head_dim)
        # The FlashInfer variant is traced with a float32 cache, the default
        # variant with a bfloat16 cache.
        if self.rope_flashinfer:
            cos_sin_cache = empty_fp32(4096, self.head_dim)
        else:
            cos_sin_cache = empty_bf16(4096, self.head_dim)
        return [
            qkv,
            positions,
            q_weight,
            k_weight,
            cos_sin_cache,
        ]
    @staticmethod
    def wrap_trace_fn(
        trace_fn: Callable[P, fx.GraphModule],
        *process_fx_fns: Callable[[fx.GraphModule], None],
    ) -> Callable[P, fx.GraphModule]:
        """Wrap ``trace_fn`` so each ``process_fx_fns`` callable runs on its output.

        The post-processing callables are applied in the order given and are
        expected to modify the graph module in place; the wrapper returns the
        same graph module object.
        """
        def wrapped(*args: P.args, **kwargs: P.kwargs) -> fx.GraphModule:
            gm = trace_fn(*args, **kwargs)
            for process_fx in process_fx_fns:
                process_fx(gm)
            return gm
        return wrapped
    @staticmethod
    def fx_view_to_reshape(gm: torch.fx.GraphModule) -> None:
        """Run inductor's ``view_to_reshape`` post-grad pass on ``gm``."""
        # Imported locally rather than at module scope.
        from torch._inductor.fx_passes.post_grad import view_to_reshape
        view_to_reshape(gm)
    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the unfused pattern and its fused replacement on ``pm_pass``."""
        def pattern(
            qkv: torch.Tensor,
            positions: torch.Tensor,
            q_weight: torch.Tensor,
            k_weight: torch.Tensor,
            cos_sin_cache: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            # split qkv -> q,k,v
            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
            # Q path: view -> RMS -> view back to q.shape
            q_by_head = q.view(
                *q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim
            )
            q_normed_by_head = self.rmsnorm_matcher(q_by_head, q_weight)
            q_flat = q_normed_by_head.view(q.shape)
            # K path: view -> RMS -> view back to k.shape
            k_by_head = k.view(
                *k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim
            )
            k_normed_by_head = self.rmsnorm_matcher(k_by_head, k_weight)
            k_flat = k_normed_by_head.view(k.shape)
            # RoPE: apply to flattened q/k
            q_rope, k_rope = self.rope_matcher(positions, q_flat, k_flat, cos_sin_cache)
            # v passes through untouched.
            return q_rope, k_rope, v
        def replacement(
            qkv: torch.Tensor,
            positions: torch.Tensor,
            q_weight: torch.Tensor,
            k_weight: torch.Tensor,
            cos_sin_cache: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            # Run fused qk_norm_rope op
            result = auto_functionalized(
                FUSED_QK_ROPE_OP,
                qkv=qkv,
                num_heads_q=self.num_heads,
                num_heads_k=self.num_kv_heads,
                num_heads_v=self.num_kv_heads,
                head_dim=self.head_dim,
                eps=self.eps,
                q_weight=q_weight,
                k_weight=k_weight,
                cos_sin_cache=cos_sin_cache,
                is_neox=self.is_neox,
                position_ids=positions.view(-1),
            )
            # Index 1 of the functionalized result is the updated qkv tensor.
            result_qkv = result[1]
            # Split back to q,k,v and return
            return result_qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)  # type: ignore[no-any-return]
        # NOTE: use fx_view_to_reshape to unify view/reshape to simplify
        # pattern and increase matching opportunities
        pm.register_replacement(
            pattern,
            replacement,
            self.get_inputs(),
            QkNormRopePattern.wrap_trace_fn(
                pm.fwd_only,
                QkNormRopePattern.fx_view_to_reshape,
            ),
            pm_pass,
        )
class QKNormRoPEFusionPass(VllmPatternMatcherPass):
    """Fuse Q/K RMSNorm + RoPE into fused_qk_norm_rope when the custom op exists."""

    @enable_fake_mode
    def __init__(self, config: VllmConfig) -> None:
        """Build the pattern set for every epsilon / neox / RoPE-op variant.

        No patterns are registered (and a warning is logged once) when the
        model dtype is neither bfloat16 nor float16, or when no Attention
        layers are found in the config.
        """
        super().__init__(config)
        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="qk_norm_rope_fusion_pass"
        )

        model_dtype = config.model_config.dtype
        if model_dtype not in (torch.bfloat16, torch.float16):
            logger.warning_once(
                "QK Norm+RoPE fusion not enabled: unsupported dtype %s", model_dtype
            )
            return

        # use one attn layer to get meta (such as head_dim) for QkNormRopePattern
        attn_layers: dict[str, Attention] = get_layers_from_vllm_config(
            config, Attention
        )
        if not attn_layers:
            logger.warning_once(
                "QK Norm+RoPE fusion enabled, but no Attention layers were discovered."
            )
            return
        first_layer = next(iter(attn_layers.values()))

        for epsilon in [1e-5, 1e-6]:
            for neox in [True, False]:
                # With the RoPE custom op enabled both the default and the
                # FlashInfer variants are registered; otherwise only the
                # default (rope_flashinfer=False) one.
                if RotaryEmbedding.enabled():
                    flashinfer_variants = [False, True]
                else:
                    flashinfer_variants = [False]
                for rope_flashinfer in flashinfer_variants:
                    rope_pattern = QkNormRopePattern(
                        head_dim=first_layer.head_size,
                        num_heads=first_layer.num_heads,
                        num_kv_heads=first_layer.num_kv_heads,
                        eps=epsilon,
                        is_neox=neox,
                        rope_flashinfer=rope_flashinfer,
                    )
                    rope_pattern.register(self.patterns)

        self.dump_patterns(config, self.patterns)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph) -> None:
        """Apply the registered patterns to ``graph`` and record the match count."""
        num_fused = self.patterns.apply(graph)
        self.matched_count = num_fused
        logger.debug("Fused QK Norm+RoPE on %s sites", self.matched_count)

    def uuid(self) -> str:
        """Return a hash of this pass's source and ``QkNormRopePattern``'s."""
        return VllmInductorPass.hash_source(self, QkNormRopePattern)

View File

@@ -0,0 +1,643 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, NamedTuple
import torch
import torch._inductor.pattern_matcher as pm
from torch import fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch._ops import OpOverload
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
QuantKey,
ScaleDesc,
kFp8Dynamic64Sym,
kFp8Dynamic128Sym,
kFp8DynamicTensorSym,
kFp8DynamicTokenSym,
kFp8StaticTensorSym,
kNvfp4Dynamic,
kStaticTensorScale,
)
from vllm.platforms import current_platform
from ..inductor_pass import enable_fake_mode
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
from .matcher_utils import (
MatcherFusedAddRMSNorm,
MatcherQuantFP8,
MatcherRMSNorm,
)
# Module-level logger.
logger = init_logger(__name__)
# FP8 dtype reported by the current platform.
FP8_DTYPE = current_platform.fp8_dtype()
# NOTE(review): presumably FP4 values are carried in uint8 storage - confirm.
FP4_DTYPE = torch.uint8
def empty_bf16(*args: Any, **kwargs: Any) -> torch.Tensor:
return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device="cuda")
def empty_fp32(*args: Any, **kwargs: Any) -> torch.Tensor:
return torch.empty(*args, **kwargs, dtype=torch.float32, device="cuda")
def empty_i32(*args: Any, **kwargs: Any) -> torch.Tensor:
return torch.empty(*args, **kwargs, dtype=torch.int32, device="cuda")
def empty_i64(*args: Any, **kwargs: Any) -> torch.Tensor:
return torch.empty(*args, **kwargs, dtype=torch.int64, device="cuda")
# Handles to the default overloads of the unfused RMSNorm custom ops.
RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
# Quantization scheme -> unfused custom quant op.
QUANT_OPS: dict[QuantKey, OpOverload] = {
    kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default,  # noqa: E501
    kFp8DynamicTensorSym: torch.ops._C.dynamic_scaled_fp8_quant.default,  # noqa: E501
    kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default,  # noqa: E501
}
# The remaining entries are only registered on CUDA platforms; the FP4 op is
# additionally guarded by its presence in torch.ops._C.
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
    QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.default
if current_platform.is_cuda():
    # Both group sizes (128 and 64) map to the same per-token-group op.
    QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default  # noqa: E501
    QUANT_OPS[kFp8Dynamic64Sym] = torch.ops._C.per_token_group_fp8_quant.default  # noqa: E501
class FusedRMSQuantKey(NamedTuple):
    """
    Key identifying one kind of RMSNorm + quant fusion.

    quant: the quantization scheme
    fused_add: whether the fused op also performs the residual add
    """

    quant: QuantKey
    fused_add: bool

    def __str__(self) -> str:
        # Renders "with residual" or "without residual".
        suffix = "" if self.fused_add else "out"
        return f"FusedQuantKey({self.quant}, with{suffix} residual)"
# Fusion key -> fused RMSNorm+quant custom op. The static-tensor scheme has
# separate ops with and without the residual add; the dynamic per-token and
# per-block (128 / 64) schemes map both variants to a single op each.
FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = {
    FusedRMSQuantKey(
        kFp8StaticTensorSym, False
    ): torch.ops._C.rms_norm_static_fp8_quant.default,  # noqa: E501
    FusedRMSQuantKey(
        kFp8StaticTensorSym, True
    ): torch.ops._C.fused_add_rms_norm_static_fp8_quant.default,  # noqa: E501
    FusedRMSQuantKey(
        kFp8DynamicTokenSym, False
    ): torch.ops._C.rms_norm_dynamic_per_token_quant.default,  # noqa: E501
    FusedRMSQuantKey(
        kFp8DynamicTokenSym, True
    ): torch.ops._C.rms_norm_dynamic_per_token_quant.default,  # noqa: E501
    FusedRMSQuantKey(
        kFp8Dynamic128Sym, False
    ): torch.ops._C.rms_norm_per_block_quant.default,  # noqa: E501
    FusedRMSQuantKey(
        kFp8Dynamic128Sym, True
    ): torch.ops._C.rms_norm_per_block_quant.default,  # noqa: E501
    FusedRMSQuantKey(
        kFp8Dynamic64Sym, False
    ): torch.ops._C.rms_norm_per_block_quant.default,  # noqa: E501
    FusedRMSQuantKey(
        kFp8Dynamic64Sym, True
    ): torch.ops._C.rms_norm_per_block_quant.default,  # noqa: E501
}
class RMSNormQuantPattern:
    """Shared setup for the RMSNorm + quant fusion patterns.

    Resolves the fused op for ``key`` and builds the RMSNorm and quant
    matchers that the subclasses' pattern functions call.
    """

    def __init__(
        self,
        epsilon: float,
        key: FusedRMSQuantKey,
        has_col_major_scales: bool = False,
        is_e8m0: bool = False,
        is_tma_aligned: bool = False,
    ) -> None:
        self.epsilon = epsilon
        self.quant_dtype = key.quant.dtype

        vllm_config = get_current_vllm_config()
        model_config = vllm_config.model_config
        # The model config may be absent, in which case the dtype is None.
        self.model_dtype = model_config.dtype if model_config else None

        assert key in FUSED_OPS, f"unsupported fused rmsnorm+quant op for {key}"
        self.FUSED_OP = FUSED_OPS[key]

        # Pick the matcher with or without the residual add.
        if key.fused_add:
            self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
        else:
            self.rmsnorm_matcher = MatcherRMSNorm(epsilon)

        self.quant_matcher = MatcherQuantFP8(
            key.quant,
            has_col_major_scales=has_col_major_scales,
            is_e8m0=is_e8m0,
            is_tma_aligned=is_tma_aligned,
        )
class RMSNormStaticQuantPattern(RMSNormQuantPattern):
    """Fuse RMSNorm (no residual) followed by static per-tensor quantization."""
    def __init__(
        self, epsilon: float, quant_dtype: torch.dtype, symmetric: bool = True
    ) -> None:
        """Build the fusion key for static per-tensor quant without residual add."""
        fused_key = FusedRMSQuantKey(
            fused_add=False,
            quant=QuantKey(
                dtype=quant_dtype, scale=kStaticTensorScale, symmetric=symmetric
            ),
        )
        super().__init__(epsilon, fused_key)
    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the unfused pattern and its fused replacement on ``pm_pass``."""
        # Cannot use methods, as the self argument affects tracing
        def pattern(
            input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
        ) -> torch.Tensor:
            result_rms = self.rmsnorm_matcher(input, weight)
            # Only the quantized tensor is returned; the scale is an input here.
            return self.quant_matcher(result_rms, scale)[0]
        def replacement(
            input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
        ) -> torch.Tensor:
            # In case we're matching native rms-norm, conversions might be
            # optimized out. We convert here just to be safe.
            input = input.to(dtype=self.model_dtype)
            result = torch.empty(
                input.shape, device=input.device, dtype=self.quant_dtype
            )
            at = auto_functionalized(
                self.FUSED_OP,
                result=result,
                input=input,
                weight=weight,
                scale=scale,
                epsilon=self.epsilon,
            )
            # result
            return at[1]
        inputs = [
            # input, weight
            *self.rmsnorm_matcher.inputs(),
            self.quant_matcher.inputs()[1],  # scale
        ]
        # NOTE(review): the pattern is called eagerly here and its result is
        # discarded; the sibling pattern classes do not do this. Confirm
        # whether this call is still needed.
        pattern(*inputs)
        pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass)
class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
    """Fuse residual-add RMSNorm followed by static per-tensor quantization."""
    def __init__(
        self, epsilon: float, quant_dtype: torch.dtype, symmetric: bool = True
    ) -> None:
        """Build the fusion key for static per-tensor quant with residual add."""
        key = FusedRMSQuantKey(
            fused_add=True,
            quant=QuantKey(
                dtype=quant_dtype, scale=kStaticTensorScale, symmetric=symmetric
            ),
        )
        super().__init__(epsilon, key)
    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the unfused pattern and its fused replacement on ``pm_pass``."""
        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
            # The scale returned by the quant matcher is not a pattern output.
            result, _ = self.quant_matcher(result_rms, scale)
            return result, residual
        def replacement(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # In case we're matching native rms-norm, conversions might be
            # optimized out. We convert here just to be safe.
            input = input.to(dtype=self.model_dtype)
            result = torch.empty_like(input, dtype=self.quant_dtype)
            at = auto_functionalized(
                self.FUSED_OP,
                result=result,
                input=input,
                residual=residual,
                weight=weight,
                scale=scale,
                epsilon=self.epsilon,
            )
            # result, residual
            return at[1], at[2]
        inputs = [
            # input, weight, residual
            *self.rmsnorm_matcher.inputs(),
            self.quant_matcher.inputs()[1],  # scale
        ]
        pm.register_replacement(
            pattern,
            replacement,
            inputs,
            pm.fwd_only,
            pm_pass,
        )
class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern):
    """Fuse residual-add RMSNorm followed by dynamic per-group quantization."""
    def __init__(
        self,
        epsilon: float,
        quant_dtype: torch.dtype,
        group_shape: GroupShape,
        symmetric: bool = True,
        is_e8m0: bool = False,
        has_col_major_scales: bool = True,
        is_tma_aligned: bool = True,
    ) -> None:
        """Build the fusion key for dynamic group quant with residual add."""
        # Dynamic (non-static) float32 scale with the requested group shape.
        scale = ScaleDesc(torch.float32, False, group_shape)
        key = FusedRMSQuantKey(
            fused_add=True,
            quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
        )
        self.group_shape = group_shape
        self.is_e8m0 = is_e8m0
        self.has_col_major_scales = has_col_major_scales
        self.is_tma_aligned = is_tma_aligned
        super().__init__(
            epsilon,
            key,
            has_col_major_scales=has_col_major_scales,
            is_e8m0=is_e8m0,
            is_tma_aligned=is_tma_aligned,
        )
    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the unfused pattern and its fused replacement on ``pm_pass``."""
        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
            # The quant op is emitted directly, with the scale buffer as a
            # pattern input, instead of going through the quant matcher call.
            result = torch.empty(
                result_rms.shape,
                device=result_rms.device,
                dtype=self.quant_matcher.quant_key.dtype,
            )
            assert scale is not None
            finfo = torch.finfo(self.quant_matcher.quant_key.dtype)
            fp8_min = finfo.min
            fp8_max = finfo.max
            _, result, scale = auto_functionalized(
                self.quant_matcher.QUANT_OP,
                input=result_rms,
                output_q=result,
                output_s=scale,
                group_size=self.quant_matcher.quant_key.scale.group_shape[1],
                eps=1e-10,
                fp8_min=fp8_min,
                fp8_max=fp8_max,
                scale_ue8m0=self.quant_matcher.is_e8m0,
                dummy_is_scale_transposed=self.has_col_major_scales,
                dummy_is_tma_aligned=self.is_tma_aligned,
            )
            return result, residual, scale
        def replacement(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            # In case we're matching native rms-norm, conversions might be
            # optimized out. We convert here just to be safe.
            input = input.to(dtype=self.model_dtype)
            result = torch.empty_like(input, dtype=self.quant_dtype)
            at = auto_functionalized(
                self.FUSED_OP,
                result=result,
                input=input,
                weight=weight,
                scale=scale,
                epsilon=self.epsilon,
                scale_ub=None,
                residual=residual,
                group_size=self.group_shape[1],
                is_scale_transposed=self.has_col_major_scales,
            )
            # result, residual, scale
            return at[1], at[3], at[2]
        # Sample 1x1 float32 scale used only for tracing the pattern.
        scale = self.quant_matcher.empty_f32(1, 1)
        pm.register_replacement(
            pattern,
            replacement,
            self.rmsnorm_matcher.inputs() + [scale],
            pm.fwd_only,
            pm_pass,
        )
class RMSNormGroupQuantPattern(RMSNormQuantPattern):
    """Fuse RMSNorm (no residual) followed by dynamic per-group quantization."""
    def __init__(
        self,
        epsilon: float,
        quant_dtype: torch.dtype,
        group_shape: GroupShape,
        symmetric: bool = True,
        is_e8m0: bool = False,
        has_col_major_scales: bool = True,
        is_tma_aligned: bool = True,
    ) -> None:
        """Build the fusion key for dynamic group quant without residual add."""
        # Dynamic (non-static) float32 scale with the requested group shape.
        scale = ScaleDesc(torch.float32, False, group_shape)
        key = FusedRMSQuantKey(
            fused_add=False,
            quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
        )
        self.group_shape = group_shape
        self.has_col_major_scales = has_col_major_scales
        self.is_tma_aligned = is_tma_aligned
        super().__init__(
            epsilon,
            key,
            has_col_major_scales=self.has_col_major_scales,
            is_e8m0=is_e8m0,
            is_tma_aligned=is_tma_aligned,
        )
    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the unfused pattern and its fused replacement on ``pm_pass``."""
        def pattern(
            input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor]:
            result_rms = self.rmsnorm_matcher(input, weight)
            # The quant op is emitted directly, with the scale buffer as a
            # pattern input, instead of going through the quant matcher call.
            result = torch.empty(
                result_rms.shape,
                device=result_rms.device,
                dtype=self.quant_matcher.quant_key.dtype,
            )
            assert scale is not None
            finfo = torch.finfo(self.quant_matcher.quant_key.dtype)
            fp8_min = finfo.min
            fp8_max = finfo.max
            _, result, scale = auto_functionalized(
                self.quant_matcher.QUANT_OP,
                input=result_rms,
                output_q=result,
                output_s=scale,
                group_size=self.quant_matcher.quant_key.scale.group_shape[1],
                eps=1e-10,
                fp8_min=fp8_min,
                fp8_max=fp8_max,
                scale_ue8m0=self.quant_matcher.is_e8m0,
                dummy_is_scale_transposed=self.has_col_major_scales,
                dummy_is_tma_aligned=self.is_tma_aligned,
            )
            return result, scale
        def replacement(
            input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # In case we're matching native rms-norm, conversions might be
            # optimized out. We convert here just to be safe.
            input = input.to(dtype=self.model_dtype)
            result = torch.empty_like(input, dtype=self.quant_dtype)
            at = auto_functionalized(
                self.FUSED_OP,
                result=result,
                input=input,
                weight=weight,
                scale=scale,
                epsilon=self.epsilon,
                scale_ub=None,
                residual=None,
                group_size=self.group_shape[1],
                is_scale_transposed=self.has_col_major_scales,
            )
            # result, scale
            return at[1], at[2]
        # Sample 1x1 float32 scale used only for tracing the pattern.
        scale = self.quant_matcher.empty_f32(1, 1)
        pm.register_replacement(
            pattern,
            replacement,
            self.rmsnorm_matcher.inputs() + [scale],
            pm.fwd_only,
            pm_pass,
        )
class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
    """Fuse RMSNorm (no residual) followed by dynamic quantization.

    The group shape defaults to per-token.
    """
    def __init__(
        self,
        epsilon: float,
        quant_dtype: torch.dtype,
        group_shape: GroupShape = GroupShape.PER_TOKEN,
        symmetric: bool = True,
    ) -> None:
        """Build the fusion key for dynamic quant without residual add."""
        # Dynamic (non-static) float32 scale with the requested group shape.
        scale = ScaleDesc(torch.float32, False, group_shape)
        key = FusedRMSQuantKey(
            fused_add=False,
            quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
        )
        super().__init__(epsilon, key)
    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the unfused pattern and its fused replacement on ``pm_pass``."""
        def pattern(
            input: torch.Tensor, weight: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor]:
            result_rms = self.rmsnorm_matcher(input, weight)
            # result, scale
            return self.quant_matcher(result_rms)  # type: ignore[no-any-return]
        def replacement(
            input: torch.Tensor, weight: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # In case we're matching native rms-norm, conversions might be
            # optimized out. We convert here just to be safe.
            input = input.to(dtype=self.model_dtype)
            result = torch.empty_like(input, dtype=self.quant_dtype)
            # The scale is an output here, so its buffer is allocated locally.
            scale = self.quant_matcher.make_scale(input)
            at = auto_functionalized(
                self.FUSED_OP,
                result=result,
                input=input,
                weight=weight,
                scale=scale,
                epsilon=self.epsilon,
                scale_ub=None,
                residual=None,
            )
            # result, scale
            return at[1], at[2]
        pm.register_replacement(
            pattern,
            replacement,
            self.rmsnorm_matcher.inputs(),
            pm.fwd_only,
            pm_pass,
        )
class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
    """Fuse residual-add RMSNorm followed by dynamic quantization.

    The group shape defaults to per-token.
    """
    def __init__(
        self,
        epsilon: float,
        quant_dtype: torch.dtype,
        group_shape: GroupShape = GroupShape.PER_TOKEN,
        symmetric: bool = True,
    ) -> None:
        """Build the fusion key for dynamic quant with residual add."""
        # Dynamic (non-static) float32 scale with the requested group shape.
        scale = ScaleDesc(torch.float32, False, group_shape)
        key = FusedRMSQuantKey(
            fused_add=True,
            quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
        )
        super().__init__(epsilon, key)
    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the unfused pattern and its fused replacement on ``pm_pass``."""
        def pattern(
            input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
            result, scale = self.quant_matcher(result_rms)
            return result, residual, scale
        def replacement(
            input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            # In case we're matching native rms-norm, conversions might be
            # optimized out. We convert here just to be safe.
            input = input.to(dtype=self.model_dtype)
            result = torch.empty_like(input, dtype=self.quant_dtype)
            # The scale is an output here, so its buffer is allocated locally.
            scale = self.quant_matcher.make_scale(input)
            at = auto_functionalized(
                self.FUSED_OP,
                result=result,
                input=input,
                weight=weight,
                scale=scale,
                epsilon=self.epsilon,
                scale_ub=None,
                residual=residual,
            )
            # result, residual, scale
            return at[1], at[3], at[2]
        pm.register_replacement(
            pattern,
            replacement,
            self.rmsnorm_matcher.inputs(),
            pm.fwd_only,
            pm_pass,
        )
class RMSNormQuantFusionPass(VllmPatternMatcherPass):
    """
    Fuses rms_norm and quant custom ops into a single fused rms_norm_quant op.
    The fused_add_rms_norm variant is handled as well.
    """

    @enable_fake_mode
    def __init__(self, config: VllmConfig) -> None:
        """Register every supported (epsilon, quant scheme) pattern."""
        super().__init__(config)
        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="rmsnorm_quant_fusion_pass"
        )

        # Ordering matters: fused-add patterns must be registered before the
        # plain rms-norm ones, since the latter is a subset of the former
        # in torch ops.
        for epsilon in [1e-5, 1e-6]:
            # fused_add_rms_norm + static fp8 quant
            FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(
                self.patterns
            )
            # rms_norm + static fp8 quant
            RMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)
            # fused_add_rms_norm + dynamic per-token fp8 quant
            FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
                self.patterns
            )
            # rms_norm + dynamic per-token fp8 quant
            RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)

            # Group quant patterns are only registered on CUDA, where the
            # C++ op exists.
            if not current_platform.is_cuda():
                continue
            for group_shape in [GroupShape(1, 128), GroupShape(1, 64)]:
                for has_col_major_scales in [True, False]:
                    for is_e8m0 in [True, False]:
                        for is_tma_aligned in [False, True]:
                            group_kwargs = dict(
                                group_shape=group_shape,
                                is_e8m0=is_e8m0,
                                has_col_major_scales=has_col_major_scales,
                                is_tma_aligned=is_tma_aligned,
                            )
                            # fused_add_rms_norm + fp8 group quant
                            FusedAddRMSNormGroupQuantPattern(
                                epsilon, FP8_DTYPE, **group_kwargs
                            ).register(self.patterns)
                            # rms_norm + fp8 group quant
                            RMSNormGroupQuantPattern(
                                epsilon, FP8_DTYPE, **group_kwargs
                            ).register(self.patterns)

        self.dump_patterns(config, self.patterns)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph) -> None:
        """Apply all registered patterns to ``graph`` and record the count."""
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)

    def uuid(self) -> str:
        """Hash the source of this pass and of every pattern class it uses."""
        hashed_sources = (
            self,
            RMSNormGroupQuantPattern,
            RMSNormQuantPattern,
            RMSNormStaticQuantPattern,
            RMSNormDynamicQuantPattern,
            FusedAddRMSNormStaticQuantPattern,
            FusedAddRMSNormDynamicQuantPattern,
            FusedAddRMSNormGroupQuantPattern,
        )
        return self.hash_source(*hashed_sources)

View File

@@ -0,0 +1,504 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import torch._inductor.pattern_matcher as pm
from torch import fx
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch._ops import OpOverload
import vllm.model_executor.layers.quantization.utils.fp8_utils # noqa: F401
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
QuantKey,
ScaleDesc,
)
from vllm.platforms import current_platform
from ..inductor_pass import enable_fake_mode
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
from .act_quant_fusion import ActivationQuantPattern
from .matcher_utils import (
MatcherFusedAddRMSNorm,
MatcherQuantFP8,
MatcherRMSNorm,
MatcherSiluAndMul,
)
from .rms_quant_fusion import (
FusedRMSQuantKey,
)
logger = init_logger(__name__)
# FP8 dtype reported by the current platform; used as the quant dtype for
# every pattern registered in this module.
FP8_DTYPE = current_platform.fp8_dtype()
class AiterRMSNormQuantPattern:
    """Shared state for the AITER rms-norm + quant fusion patterns.

    Holds the epsilon, the quant dtype and the two matcher objects that the
    concrete subclasses use to build their search patterns.
    """

    def __init__(
        self, epsilon: float, key: FusedRMSQuantKey, match_aiter_quant: bool = True
    ):
        """
        Args:
            epsilon: RMSNorm epsilon baked into the pattern.
            key: Describes the fused op (residual add or not, quant scheme).
            match_aiter_quant: Forwarded to ``MatcherQuantFP8`` as
                ``match_rocm_aiter``.
        """
        self.epsilon = epsilon
        self.quant_dtype = key.quant.dtype

        # Choose the rms-norm matcher depending on whether the residual add
        # is part of the fused op.
        if key.fused_add:
            self.rmsnorm_matcher = MatcherFusedAddRMSNorm(
                epsilon, match_rocm_aiter=True
            )
        else:
            self.rmsnorm_matcher = MatcherRMSNorm(epsilon, match_rocm_aiter=True)

        self.quant_matcher = MatcherQuantFP8(
            key.quant,
            match_rocm_aiter=match_aiter_quant,
        )
class AiterRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern):
    """AITER RMSNorm followed by dynamic quantization, fused into one op."""

    FUSED_OP = rocm_aiter_ops.get_rmsnorm_fused_dynamic_quant_op()

    def __init__(
        self,
        epsilon: float,
        quant_dtype: torch.dtype,
        match_aiter_quant: bool = True,
        group_shape: GroupShape = GroupShape.PER_TOKEN,
        symmetric: bool = True,
    ) -> None:
        """Build the fused-op key (no residual add) and initialise the base."""
        scale_desc = ScaleDesc(torch.float32, False, group_shape)
        quant_key = QuantKey(dtype=quant_dtype, scale=scale_desc, symmetric=symmetric)
        fused_key = FusedRMSQuantKey(fused_add=False, quant=quant_key)
        super().__init__(epsilon, fused_key, match_aiter_quant)

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the search/replace pair on ``pm_pass``."""

        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            normed = self.rmsnorm_matcher(input, weight)
            quantized, scale = self.quant_matcher(normed)
            return quantized, scale

        def replacement(
            input: torch.Tensor,
            weight: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            fused = self.FUSED_OP(
                x=input,
                weight=weight,
                epsilon=self.epsilon,
                quant_dtype=self.quant_dtype,
            )
            # Forward the first two outputs in the pattern's (result, scale)
            # order.
            return fused[0], fused[1]

        example_inputs = self.rmsnorm_matcher.inputs()
        pm.register_replacement(
            pattern,
            replacement,
            example_inputs,
            pm.fwd_only,
            pm_pass,
        )
class AiterFusedAddRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern):
    """AITER fused-add RMSNorm followed by dynamic quantization, fused."""

    FUSED_OP = rocm_aiter_ops.get_rmsnorm_fused_add_dynamic_quant_op()

    def __init__(
        self,
        epsilon: float,
        quant_dtype: torch.dtype,
        match_aiter_quant: bool = True,
        group_shape: GroupShape = GroupShape.PER_TOKEN,
        symmetric: bool = True,
    ) -> None:
        """Build the fused-op key (with residual add) and initialise the base."""
        scale_desc = ScaleDesc(torch.float32, False, group_shape)
        quant_key = QuantKey(dtype=quant_dtype, scale=scale_desc, symmetric=symmetric)
        fused_key = FusedRMSQuantKey(fused_add=True, quant=quant_key)
        super().__init__(epsilon, fused_key, match_aiter_quant)

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the search/replace pair on ``pm_pass``."""

        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            normed, new_residual = self.rmsnorm_matcher(input, weight, residual)
            quantized, scale = self.quant_matcher(normed)
            return quantized, new_residual, scale

        def replacement(
            input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            fused = self.FUSED_OP(
                x=input,
                residual=residual,
                weight=weight,
                epsilon=self.epsilon,
                quant_dtype=self.quant_dtype,
            )
            # Outputs are forwarded positionally and must line up with the
            # pattern's (result, residual_out, scale) order.
            return fused[0], fused[1], fused[2]

        example_inputs = self.rmsnorm_matcher.inputs()
        pm.register_replacement(
            pattern,
            replacement,
            example_inputs,
            pm.fwd_only,
            pm_pass,
        )
class AiterRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern):
    """
    This pattern fuses aiter rms_norm & group fp8 quant custom
    ops into an aiter rms_norm_group_fp8_quant op.
    """

    FUSED_OP = rocm_aiter_ops.get_rmsnorm_group_fused_quant_op()

    def __init__(
        self,
        epsilon: float,
        quant_dtype: torch.dtype,
        group_shape: GroupShape,
        match_aiter_quant: bool = True,
        symmetric: bool = True,
    ) -> None:
        """
        Args:
            epsilon: RMSNorm epsilon baked into the pattern.
            quant_dtype: Output dtype of the quantized tensor.
            group_shape: Quantization group shape; its second entry is the
                group size handed to the fused op.
            match_aiter_quant: Forwarded to the base class, which passes it
                to ``MatcherQuantFP8`` as ``match_rocm_aiter``.
            symmetric: Forwarded to ``QuantKey``.
        """
        scale = ScaleDesc(torch.float32, False, group_shape)
        key = FusedRMSQuantKey(
            fused_add=False,
            quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
        )
        super().__init__(epsilon, key, match_aiter_quant)
        # The search pattern is built from `group_shape`, so the replacement
        # must quantize with the same group size rather than a fixed 128.
        self.group_size = group_shape[1]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the search/replace pair on ``pm_pass``."""

        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            result_rms = self.rmsnorm_matcher(input, weight)
            result, scale = self.quant_matcher(result_rms)
            return result, scale

        def replacement(
            input: torch.Tensor,
            weight: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            at = self.FUSED_OP(
                x=input,
                weight=weight,
                variance_epsilon=self.epsilon,
                group_size=self.group_size,
            )
            # Forwarded in the pattern's (result, scale) order.
            return at[0], at[1]

        pm.register_replacement(
            pattern, replacement, self.rmsnorm_matcher.inputs(), pm.fwd_only, pm_pass
        )
class AiterFusedAddRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern):
    """
    This pattern fuses aiter rms_norm_with_add & group fp8 quant custom ops
    into an aiter rms_norm_with_add_group_fp8_quant op.
    """

    FUSED_OP = rocm_aiter_ops.get_rmsnorm_group_add_fused_quant_op()

    def __init__(
        self,
        epsilon: float,
        quant_dtype: torch.dtype,
        group_shape: GroupShape,
        match_aiter_quant: bool = True,
        symmetric: bool = True,
    ) -> None:
        """
        Args:
            epsilon: RMSNorm epsilon baked into the pattern.
            quant_dtype: Output dtype of the quantized tensor.
            group_shape: Quantization group shape; its second entry is the
                group size handed to the fused op.
            match_aiter_quant: Forwarded to the base class, which passes it
                to ``MatcherQuantFP8`` as ``match_rocm_aiter``.
            symmetric: Forwarded to ``QuantKey``.
        """
        scale = ScaleDesc(torch.float32, False, group_shape)
        key = FusedRMSQuantKey(
            fused_add=True,
            quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
        )
        super().__init__(epsilon, key, match_aiter_quant)
        # The search pattern is built from `group_shape`, so the replacement
        # must quantize with the same group size rather than a fixed 128.
        self.group_size = group_shape[1]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the search/replace pair on ``pm_pass``."""

        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            result_rms, residual_out = self.rmsnorm_matcher(input, weight, residual)
            result, scale = self.quant_matcher(result_rms)
            return result, residual_out, scale

        def replacement(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            at = self.FUSED_OP(
                x=input,
                residual=residual,
                weight=weight,
                variance_epsilon=self.epsilon,
                group_size=self.group_size,
            )
            # Outputs are forwarded positionally and must line up with the
            # pattern's (result, residual_out, scale) order.
            # NOTE(review): the previous comment here labelled the outputs
            # "result, scale, residual", which disagrees with the pattern's
            # order -- confirm the fused op's actual return order.
            return at[0], at[1], at[2]

        pm.register_replacement(
            pattern, replacement, self.rmsnorm_matcher.inputs(), pm.fwd_only, pm_pass
        )
class RocmAiterRMSNormQuantFusionPass(VllmPatternMatcherPass):
    """
    Fuses aiter rms_norm with vllm/aiter quant custom ops into a single
    fused rms_norm_quant op. The fused_add_rms_norm variant is handled too.
    """

    @enable_fake_mode
    def __init__(self, config: VllmConfig) -> None:
        """Register every supported AITER rms-norm + quant pattern."""
        super().__init__(config)
        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="rocm_aiter_rms_norm_quant_fusion_pass"
        )

        # Make sure fused add patterns are before simple rms norm,
        # as the latter is a subset of the former in torch ops
        group_128 = GroupShape(1, 128)
        for epsilon in [1e-5, 1e-6]:
            # aiter rms_norm + aiter dynamic group fp8 quant
            group_pattern = AiterRMSFp8GroupQuantPattern(epsilon, FP8_DTYPE, group_128)
            group_pattern.register(self.patterns)

            # aiter fused_add_rms_norm + aiter dynamic group fp8 quant
            add_group_pattern = AiterFusedAddRMSFp8GroupQuantPattern(
                epsilon, FP8_DTYPE, group_128
            )
            add_group_pattern.register(self.patterns)

            for match_aiter_quant in [True, False]:
                # aiter rms_norm + (aiter / vllm built-in)
                # dynamic per-token fp8 quant
                AiterRMSNormDynamicQuantPattern(
                    epsilon, FP8_DTYPE, match_aiter_quant=match_aiter_quant
                ).register(self.patterns)

                # aiter fused_add_rms_norm + (aiter / vllm built-in)
                # dynamic per-token fp8 quant
                AiterFusedAddRMSNormDynamicQuantPattern(
                    epsilon, FP8_DTYPE, match_aiter_quant=match_aiter_quant
                ).register(self.patterns)

        self.dump_patterns(config, self.patterns)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph) -> None:
        """Apply all registered patterns to ``graph`` and record the count."""
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)

    def uuid(self) -> str:
        """Hash the source of this pass and of its pattern classes."""
        return self.hash_source(
            self,
            AiterRMSNormDynamicQuantPattern,
            AiterFusedAddRMSNormDynamicQuantPattern,
            AiterRMSFp8GroupQuantPattern,
            AiterFusedAddRMSFp8GroupQuantPattern,
        )
class AiterSiluMulFp8GroupQuantPattern(ActivationQuantPattern):
    """
    Fuses aiter silu_and_mul followed by a group fp8 quant custom op into
    a single aiter silu_and_mul_group_fp8_quant op.
    """

    FUSED_SILU_MUL_QUANT_OP = rocm_aiter_ops.get_act_mul_fused_fp8_group_quant_op()

    def __init__(self, quant_op: OpOverload) -> None:
        """Remember which group-quant op overload the pattern should match."""
        self.silu_and_mul_matcher = MatcherSiluAndMul()
        self.quant_op = quant_op

    def get_inputs(self) -> list[torch.Tensor]:
        """Return the single example input used to trace the pattern."""
        first_input = self.silu_and_mul_matcher.inputs()[0]
        return [first_input]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the search/replace pair on ``pm_pass``."""

        def pattern(
            input: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            activated = self.silu_and_mul_matcher(input)
            # The quant op is called with a fixed group size of 128.
            quantized = self.quant_op(activated, 128)
            return quantized[0], quantized[1]

        def replacement(
            input: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            fused = self.FUSED_SILU_MUL_QUANT_OP(x=input, group_size=128)
            return fused[0], fused[1]

        example_inputs = self.get_inputs()
        pm.register_replacement(
            pattern, replacement, example_inputs, pm.fwd_only, pm_pass
        )
class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass):
    """
    Fuses a pre-defined set of custom ops into fused ops, using the torch
    pattern matcher to locate and replace them.

    Because patterns can only be registered once, the pass is a singleton.
    This will be addressed in a future version of PyTorch:
    https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
    """

    AITER_GROUP_FP8_QUANT_OP = rocm_aiter_ops.get_group_quant_op()
    TRITON_GROUP_FP8_QUANT_OP = torch.ops.vllm.triton_per_token_group_quant_fp8.default

    # One pattern is registered per quant op listed here.
    QUANT_OPS = [AITER_GROUP_FP8_QUANT_OP, TRITON_GROUP_FP8_QUANT_OP]

    @enable_fake_mode
    def __init__(self, config: VllmConfig) -> None:
        """Register one silu-mul + group-quant pattern per quant op."""
        super().__init__(config)
        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="rocm_aiter_silu_mul_fp8_group_quant_fusion_pass"
        )

        for quant_op in self.QUANT_OPS:
            silu_quant_pattern = AiterSiluMulFp8GroupQuantPattern(quant_op)
            silu_quant_pattern.register(self.patterns)

        self.dump_patterns(config, self.patterns)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: torch.fx.Graph) -> None:
        """Apply all registered patterns to ``graph`` and record the count."""
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)

    def uuid(self) -> str:
        """Hash the source of this pass and of its pattern classes."""
        return VllmInductorPass.hash_source(
            self,
            ActivationQuantPattern,
            AiterSiluMulFp8GroupQuantPattern,
        )
class AddAiterRMSNormPadPattern:
    """
    Replaces an aiter_rmsnorm_with_add followed by a pad op with the custom
    triton_add_rmsnorm_pad op from AITER.
    """

    AITER_TRITON_ADD_RMSNORM_PAD_OP = rocm_aiter_ops.get_triton_add_rmsnorm_pad_op()

    def __init__(
        self,
        epsilon: float,
        hidden_size: int,
        x_pad_to_multiple: int,
    ):
        """
        Args:
            epsilon: RMSNorm epsilon baked into the pattern.
            hidden_size: Unpadded width of the normalized tensor.
            x_pad_to_multiple: Multiple the width is padded up to.
        """
        self.epsilon = epsilon
        self.hidden_size = hidden_size
        self.x_pad_to_multiple = x_pad_to_multiple
        self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon, match_rocm_aiter=True)

    def get_inputs(self) -> list[torch.Tensor]:
        """Return example tensors for tracing, including router weight/bias."""
        input, weight, residual = self.rmsnorm_matcher.inputs()
        # Router tensors only need matching dtype/device; shapes are samples.
        like_weight = dict(dtype=weight.dtype, device=weight.device)
        router_weight = torch.empty([8, 16], **like_weight)
        router_bias = torch.empty([8], **like_weight)
        return [input, weight, residual, router_weight, router_bias]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the search/replace pair on ``pm_pass``."""
        # Amount of right-padding the matched pad op must apply.
        # NOTE(review): when hidden_size is already a multiple this evaluates
        # to a full x_pad_to_multiple rather than 0 -- confirm intended.
        remainder = self.hidden_size % self.x_pad_to_multiple
        pad_size = self.x_pad_to_multiple - remainder

        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
            router_weight: torch.Tensor,
            router_bias: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            normed, new_residual = self.rmsnorm_matcher(input, weight, residual)
            logits = torch.ops.vllm.rocm_unquantized_gemm(
                normed, router_weight, router_bias
            )
            padded = torch.nn.functional.pad(
                normed, (0, pad_size), mode="constant", value=0.0
            )
            return padded, new_residual, logits

        def replacement(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
            router_weight: torch.Tensor,
            router_bias: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            fused = self.AITER_TRITON_ADD_RMSNORM_PAD_OP(
                x=input,
                weight=weight,
                variance_epsilon=self.epsilon,
                residual=residual,
                x_pad_to_multiple=self.x_pad_to_multiple,
            )
            padded = fused[0]
            # The router gemm consumes only the unpadded leading columns.
            unpadded = padded[:, : self.hidden_size]
            logits = torch.ops.vllm.rocm_unquantized_gemm(
                unpadded, router_weight, router_bias
            )
            new_residual = fused[1]
            return padded, new_residual, logits

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )
class RocmAiterTritonAddRMSNormPadFusionPass(VllmPatternMatcherPass):
    """
    Replaces an AITER CK RMSNorm + residual add followed by a pad op
    with a triton_add_rmsnorm_pad op from AITER.
    """

    def __init__(self, config: VllmConfig):
        """Register the pad-fusion pattern for each epsilon / pad multiple."""
        super().__init__(config)
        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="rocm_aiter_triton_add_rmsnorm_pad_fusion_pass"
        )

        # gpt-oss has hidden size 2880
        # padded to a multiple of 128 on gfx942 and 256 on gfx950 respectively
        hidden_size = 2880
        pad_multiples = [128, 256]

        for epsilon in [1e-5, 1e-6]:
            for x_pad_to_multiple in pad_multiples:
                pad_pattern = AddAiterRMSNormPadPattern(
                    epsilon, hidden_size, x_pad_to_multiple
                )
                pad_pattern.register(self.patterns)

        self.dump_patterns(config, self.patterns)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: torch.fx.Graph) -> None:
        """Apply all registered patterns to ``graph`` and record the count."""
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)

    def uuid(self) -> str:
        """Hash the source of this pass and of its pattern class."""
        return VllmInductorPass.hash_source(self, AddAiterRMSNormPadPattern)

View File

@@ -0,0 +1,230 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import torch._inductor.pattern_matcher as pm
from torch import fx
from torch._higher_order_ops import auto_functionalized
from torch._inductor.fx_passes.post_grad import view_to_reshape
from torch._inductor.pattern_matcher import PatternMatcherPass
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.config.utils import Range
from vllm.logger import init_logger
from vllm.model_executor.layers.attention.attention import (
Attention,
get_attention_context,
)
from vllm.utils.torch_utils import direct_register_custom_op
from ..inductor_pass import enable_fake_mode
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
from .matcher_utils import (
MatcherRotaryEmbedding,
)
from .rms_quant_fusion import (
empty_bf16,
empty_i64,
)
logger = init_logger(__name__)
def fused_rope_and_unified_kv_cache_update_impl(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    positions: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
    layer_name: str = "",
) -> torch.Tensor:
    """
    Look up the KV cache and slot mapping for ``layer_name`` in the forward
    context and delegate to the layer impl's
    `AttentionImpl.do_rope_and_kv_cache_update`.

    The update is skipped when the layer has no slot mapping. Either way an
    empty dummy tensor is returned (similar to
    `Attention.unified_kv_cache_update`); it is passed to unified_attention
    to signal a side effect and a data dependency, so that torch.compile
    preserves the ordering between the two ops.
    """
    context = get_attention_context(layer_name)
    _, attn_layer, kv_cache, layer_slot_mapping = context

    has_slot_mapping = layer_slot_mapping is not None
    if has_slot_mapping:
        attn_layer.impl.do_rope_and_kv_cache_update(
            attn_layer,
            query,
            key,
            value,
            positions,
            cos_sin_cache,
            is_neox,
            kv_cache,
            layer_slot_mapping,
        )

    # Zero-element placeholder carrying the KV cache's device and dtype.
    dummy = torch.empty(0, device=kv_cache.device, dtype=kv_cache.dtype)
    return dummy
def fused_rope_and_unified_kv_cache_update_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    positions: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
    layer_name: str = "",
) -> torch.Tensor:
    """Fake implementation: return an empty dummy tensor shaped like ``query``'s
    device and dtype, without touching any cache."""
    dummy = torch.empty(0, device=query.device, dtype=query.dtype)
    return dummy
# Register the fused RoPE + KV-cache update as a custom op; it is looked up
# below as torch.ops.vllm.fused_rope_and_unified_kv_cache_update.
direct_register_custom_op(
    op_name="fused_rope_and_unified_kv_cache_update",
    op_func=fused_rope_and_unified_kv_cache_update_impl,
    # query and key are declared as mutated by the op.
    mutates_args=["query", "key"],
    fake_impl=fused_rope_and_unified_kv_cache_update_fake,
)
class RopeReshapeKVCachePattern:
    """
    Matches the following unfused inplace ops:
      q, k = rotary_embedding(positions, q, k, head_size, cos_sin_cache, is_neox)
      kv_cache_dummy = unified_kv_cache_update(k, v, layer_name)

    and rewrites them into the fused inplace op:
      kv_cache_dummy = fused_rope_and_unified_kv_cache_update(
          q, k, v, positions, cos_sin_cache, is_neox, layer_name
      )
    """

    FUSED_OP = torch.ops.vllm.fused_rope_and_unified_kv_cache_update.default

    def __init__(
        self,
        layer: Attention,
        is_neox: bool,
    ) -> None:
        """Capture the layer's head geometry and build the RoPE matcher."""
        self.layer_name = layer.layer_name
        self.is_neox = is_neox

        self.num_heads = layer.num_heads
        self.num_kv_heads = layer.num_kv_heads
        self.head_size = layer.head_size
        self.head_size_v = layer.head_size_v

        # Widths of the q / k / v slices of the packed qkv tensor.
        self.q_size = self.num_heads * self.head_size
        self.k_size = self.num_kv_heads * self.head_size
        self.v_size = self.num_kv_heads * self.head_size_v

        self.rope_matcher = MatcherRotaryEmbedding(
            is_neox=self.is_neox,
            head_size=self.head_size,
            num_heads=self.num_heads,
            num_kv_heads=self.num_kv_heads,
        )

    def get_inputs(self) -> list[torch.Tensor]:
        """Return sample inputs used only to trace the pattern."""
        num_tokens = 5
        cache_len = 4096
        qkv_width = self.q_size + self.k_size + self.v_size
        qkv = empty_bf16(num_tokens, qkv_width)
        positions = empty_i64(num_tokens)
        cos_sin_cache = empty_bf16(cache_len, self.head_size)
        return [qkv, positions, cos_sin_cache]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the search/replace pair on ``pm_pass``."""

        def pattern(
            qkv: torch.Tensor,
            positions: torch.Tensor,
            cos_sin_cache: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
            split_sizes = [self.q_size, self.k_size, self.v_size]
            q, k, v = qkv.split(split_sizes, dim=-1)
            q, k = self.rope_matcher(positions, q, k, cos_sin_cache)
            q = q.view(-1, self.num_heads, self.head_size)
            k = k.view(-1, self.num_kv_heads, self.head_size)
            v = v.view(-1, self.num_kv_heads, self.head_size_v)
            kv_dummy = torch.ops.vllm.unified_kv_cache_update(k, v, self.layer_name)
            return kv_dummy, q, k, v

        def replacement(
            qkv: torch.Tensor,
            positions: torch.Tensor,
            cos_sin_cache: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
            split_sizes = [self.q_size, self.k_size, self.v_size]
            q, k, v = qkv.split(split_sizes, dim=-1)
            q = q.view(-1, self.num_heads, self.head_size)
            k = k.view(-1, self.num_kv_heads, self.head_size)
            v = v.view(-1, self.num_kv_heads, self.head_size_v)
            fused = auto_functionalized(
                self.FUSED_OP,
                query=q,
                key=k,
                value=v,
                positions=positions,
                cos_sin_cache=cos_sin_cache,
                is_neox=self.is_neox,
                layer_name=self.layer_name,
            )
            # First three outputs line up with the pattern's (dummy, q, k);
            # v is passed through unchanged.
            return fused[0], fused[1], fused[2], v

        # NOTE: use view_to_reshape to unify view/reshape to simplify
        # pattern and increase matching opportunities
        def fwd_and_view_to_reshape(*args, **kwargs) -> fx.GraphModule:
            traced = pm.fwd_only(*args, **kwargs)
            view_to_reshape(traced)
            return traced

        example_inputs = self.get_inputs()
        pm.register_replacement(
            pattern, replacement, example_inputs, fwd_and_view_to_reshape, pm_pass
        )
class RopeKVCacheFusionPass(VllmPatternMatcherPass):
    """
    Fuses the rotary embedding and KV cache update operations into a single
    fused kernel when one is available.

    The pattern matcher is used with one pattern per attention layer, because
    strings (the layer name) cannot be wildcarded. This also allows checking
    support on each attention layer at registration time instead of during
    pattern matching.

    The fusion removes the separate kernel launches and the intermediate
    memory operations between the RoPE and cache update steps.
    """

    @enable_fake_mode
    def __init__(self, config: VllmConfig) -> None:
        """Register patterns for every attention layer that supports fusion."""
        super().__init__(config)
        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="rope_kv_cache_fusion_pass"
        )

        pass_config = config.compilation_config.pass_config
        self.max_token_num = pass_config.rope_kvcache_fusion_max_token_num

        attn_layers = get_layers_from_vllm_config(config, Attention)
        for _layer_name, layer in attn_layers.items():
            if not layer.impl.fused_rope_kvcache_supported():
                continue
            # Both RoPE styles are registered for each supported layer.
            for is_neox in [True, False]:
                layer_pattern = RopeReshapeKVCachePattern(
                    layer=layer,
                    is_neox=is_neox,
                )
                layer_pattern.register(self.patterns)

        self.dump_patterns(config, self.patterns)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph) -> None:
        """Apply all registered patterns to ``graph`` and record the count."""
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)

    def is_applicable_for_range(self, compile_range: Range) -> bool:
        """Return True when the compile range ends at or below the token cap."""
        # This pass works best for the small-batch decode setting.
        # For large-batch e.g. prefill, it is better to use two separate kernels
        # since they are compute bound and the fused kernels require further tuning.
        range_end = compile_range.end
        return range_end <= self.max_token_num

    def uuid(self) -> str:
        """Hash the source of this pass and of its pattern class."""
        return VllmInductorPass.hash_source(self, RopeReshapeKVCachePattern)

View File

@@ -0,0 +1,452 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from collections.abc import Callable, Sequence
from typing import Any
import torch
import torch._inductor.pattern_matcher as pm
import torch.fx as fx
from torch._inductor.pattern_matcher import PatternMatcherPass
from vllm.config import VllmConfig
from vllm.config.utils import Range
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8StaticTensorSym,
)
from vllm.platforms import current_platform
from ..inductor_pass import enable_fake_mode
from ..utility.noop_elimination import NoOpEliminationPass
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm
logger = init_logger(__name__)
# Min hidden size per device capability for sequence parallelism.
# Only apply sequence parallelism for models with hidden_size >= threshold.
# Keys are looked up with `capability.to_int()` in
# get_sequence_parallelism_threshold; capabilities without an entry
# disable sequence parallelism there.
SP_MIN_HIDDEN_SIZE: dict[int, int] = {
    90: 8192,  # H100: only for models with hidden_size >= 8192
}

# Min size per GPU (in MiB) per device capability for sequence parallelism.
# Total min size = min_per_gpu_size * tp_size
# This ensures the threshold scales appropriately with tensor parallelism.
SP_MIN_PER_GPU_SIZE_MB: dict[int, float] = {
    90: 8,  # 8MB per GPU for H100
}
def get_sequence_parallelism_threshold(
    hidden_size: int,
    tp_size: int,
    element_size: int,
) -> int | None:
    """
    Compute the minimum token count at which sequence parallelism is applied.

    Returns None whenever sequence parallelism should not be applied at all:
    - the platform is not CUDA, or its device capability is unknown;
    - the device capability has no entry in SP_MIN_HIDDEN_SIZE or
      SP_MIN_PER_GPU_SIZE_MB;
    - hidden_size is below SP_MIN_HIDDEN_SIZE[device_capability].

    Otherwise the threshold is derived from the per-GPU size:

        min_token_num = (min_per_gpu_size_mb * tp_size * MiB) //
                        (hidden_size * element_size)
    """
    from vllm.platforms import current_platform

    if not current_platform.is_cuda():
        return None

    capability = current_platform.get_device_capability()
    if capability is None:
        return None

    # Both tables are keyed by the integer form of the device capability.
    capability_key = capability.to_int()
    hidden_floor = SP_MIN_HIDDEN_SIZE.get(capability_key)
    per_gpu_mb = SP_MIN_PER_GPU_SIZE_MB.get(capability_key)

    # Unconfigured device, or model too small for this device.
    if hidden_floor is None or per_gpu_mb is None or hidden_size < hidden_floor:
        return None

    bytes_per_token = hidden_size * element_size
    min_total_bytes = per_gpu_mb * (1024 * 1024) * tp_size
    return int(min_total_bytes // bytes_per_token)
def get_first_out_wrapper(
    fn: Callable[..., Sequence[torch.Tensor]],
) -> Callable[..., torch.Tensor]:
    """Wrap ``fn`` so that only the first element of its result is returned.

    The wrapper accepts positional arguments only and keeps ``fn``'s metadata
    via ``functools.wraps``.
    """

    @functools.wraps(fn)
    def first_output(*args: Any) -> torch.Tensor:
        outputs = fn(*args)
        return outputs[0]

    return first_output
class _SequenceParallelPatternHelper:
    """Helper for sequence parallelism patterns.

    Stores the pattern parameters together with the tensor-parallel group and
    world size, and exposes thin wrappers around the collective ops.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str | None,
    ) -> None:
        """Record pattern parameters and look up the current TP group/size."""
        self.epsilon = epsilon
        self.dtype = dtype
        self.device = device
        self.tp_group = get_tp_group()
        self.tp_size = get_tensor_model_parallel_world_size()

    def _all_reduce(self, x: torch.Tensor) -> torch.Tensor:
        """All-reduce ``x`` across the tensor-parallel group."""
        return tensor_model_parallel_all_reduce(x)

    def _reduce_scatter(self, x: torch.Tensor) -> torch.Tensor:
        """Reduce-scatter ``x`` along dim 0 across the tensor-parallel group."""
        group_name = self.tp_group.unique_name
        return torch.ops.vllm.reduce_scatter.default(
            x, dim=0, world_size=self.tp_size, group_name=group_name
        )

    def _all_gather(self, x: torch.Tensor) -> torch.Tensor:
        """All-gather ``x`` along dim 0 across the tensor-parallel group."""
        group_name = self.tp_group.unique_name
        return torch.ops.vllm.all_gather.default(
            x, dim=0, world_size=self.tp_size, group_name=group_name
        )
class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
    """Rewrites all_reduce + rms_norm into reduce_scatter + rms_norm +
    all_gather (no residual input)."""

    def __init__(self, epsilon: float, dtype: torch.dtype, device: str | None) -> None:
        """Initialise the helper and the plain rms-norm matcher."""
        super().__init__(epsilon, dtype, device)
        self.rmsnorm_matcher = MatcherRMSNorm(epsilon)

    def get_inputs(self) -> list[torch.Tensor]:
        """Return sample (input, weight) tensors used to trace the pattern."""
        tensor_opts = dict(device=self.device, dtype=self.dtype)
        sample_input = torch.empty([1, 8, 4], **tensor_opts)
        sample_weight = torch.empty([4], **tensor_opts)
        return [sample_input, sample_weight]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the search/replace pair on ``pm_pass``."""

        def pattern(
            input: torch.Tensor,
            arg3_1: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            reduced = self._all_reduce(input)
            normed = self.rmsnorm_matcher(reduced, arg3_1)
            return normed, reduced

        def replacement(
            input: torch.Tensor,
            arg3_1: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            scattered = self._reduce_scatter(input)
            normed = self.rmsnorm_matcher(scattered, arg3_1)
            gathered = self._all_gather(normed)
            # The second output is the scattered tensor, taking the place of
            # the pattern's all-reduced tensor.
            return gathered, scattered

        example_inputs = self.get_inputs()
        pm.register_replacement(
            pattern, replacement, example_inputs, pm.fwd_only, pm_pass
        )
class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
    """Rewrites all_reduce + fused_add_rms_norm into reduce_scatter +
    fused_add_rms_norm + all_gather.

    Two variants are registered: one returning both the normalized output and
    the residual, and one returning only the first output.
    """

    def __init__(self, epsilon: float, dtype: torch.dtype, device: str | None) -> None:
        """Initialise the helper and the fused-add rms-norm matcher."""
        super().__init__(epsilon, dtype, device)
        self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)

    def get_inputs(self) -> list[torch.Tensor]:
        """Return sample (residual, mm_1, weights) tensors for tracing."""
        tensor_opts = dict(device=self.device, dtype=self.dtype)
        mm_1 = torch.empty([4, 4], **tensor_opts)
        residual = torch.empty([4, 4], **tensor_opts)
        rms_norm_weights = torch.empty([4, 4], **tensor_opts)
        return [residual, mm_1, rms_norm_weights]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register both search/replace pairs on ``pm_pass``."""

        def pattern(
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            reduced = self._all_reduce(mm_1)
            norm_out = self.rmsnorm_matcher(reduced, rms_norm_weights, residual)
            return norm_out[0], norm_out[1]

        def replacement(
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # pattern matcher replaces from top-to-bottom,
            # so residual is still the full size here.
            # add a temporary slice which will become a noop
            # once the seqpar pattern with the previous rmsnorm is replaced
            scattered = self._reduce_scatter(mm_1)
            residual = residual[0 : scattered.size(0), ...]
            norm_out = self.rmsnorm_matcher(scattered, rms_norm_weights, residual)
            gathered = self._all_gather(norm_out[0])
            # shape of residual changes but that's fine,
            # next node is already slicing it, now becomes a noop
            return gathered, norm_out[1]

        # Variant 1: both outputs (normalized tensor and residual) are used.
        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )

        # Variant 2: only the first output is used.
        pm.register_replacement(
            get_first_out_wrapper(pattern),
            get_first_out_wrapper(replacement),
            self.get_inputs(),
            pm.fwd_only,
            pm_pass,
        )
# FP8 dtype reported by the current platform.
FP8_DTYPE = current_platform.fp8_dtype()
class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
    """Rewrites all_reduce + rms_norm + static FP8 quant into reduce_scatter +
    rms_norm + static FP8 quant + all_gather (no residual input)."""

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str | None,
    ) -> None:
        """Initialise the helper plus the rms-norm and FP8 quant matchers."""
        super().__init__(epsilon, dtype, device)
        self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
        self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)

    def get_inputs(self) -> list[torch.Tensor]:
        """Return sample (input, weight, scale) tensors for tracing."""
        sample_input = torch.zeros([1, 8, 4], device=self.device, dtype=self.dtype)
        sample_weight = torch.empty([4], device=self.device, dtype=self.dtype)
        # The scale is a float32 scalar tensor.
        sample_scale = torch.tensor(1.0, device=self.device, dtype=torch.float32)
        return [sample_input, sample_weight, sample_scale]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the search/replace pair on ``pm_pass``."""

        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            reduced = self._all_reduce(input)
            normed = self.rmsnorm_matcher(reduced, weight)
            quantized, _ = self.quant_matcher(normed, scale)
            return quantized, reduced

        def replacement(
            input: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            scattered = self._reduce_scatter(input)
            normed = self.rmsnorm_matcher(scattered, weight)
            quantized, _ = self.quant_matcher(normed, scale)
            # Gather the already-quantized tensor.
            gathered = self._all_gather(quantized)
            return gathered, scattered

        example_inputs = self.get_inputs()
        pm.register_replacement(
            pattern, replacement, example_inputs, pm.fwd_only, pm_pass
        )
class MiddleAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
    """
    Rewrites ``all_reduce -> fused add + rms_norm -> static FP8 quant`` into
    ``reduce_scatter -> fused add + rms_norm -> static FP8 quant -> all_gather``.

    Two variants are registered: one returning ``(quant, residual_out)`` and
    one wrapped with ``get_first_out_wrapper``.
    """

    def __init__(self, epsilon: float, dtype: torch.dtype, device: str | None) -> None:
        super().__init__(epsilon, dtype, device)
        self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
        self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)

    def get_inputs(self) -> list[torch.Tensor]:
        """Return example ``[residual, mm_1, rms_norm_weights, scale]`` tensors."""
        mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        scale = torch.empty([1, 1], device=self.device, dtype=torch.float32)
        return [residual, mm_1, rms_norm_weights, scale]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register both pattern/replacement variants on ``pm_pass``."""

        def pattern(
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            reduced = self._all_reduce(mm_1)
            normed, new_residual = self.rmsnorm_matcher(
                reduced, rms_norm_weights, residual
            )
            quantized, _ = self.quant_matcher(normed, scale)
            return quantized, new_residual

        def replacement(
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # pattern matcher replaces from top-to-bottom,
            # so residual is still the full size here.
            # add a temporary slice which will become a noop
            # once the seqpar pattern with the previous rmsnorm is replaced
            scattered = self._reduce_scatter(mm_1)
            residual = residual[0 : scattered.size(0), ...]
            normed, new_residual = self.rmsnorm_matcher(
                scattered, rms_norm_weights, residual
            )
            quantized, _ = self.quant_matcher(normed, scale)
            gathered = self._all_gather(quantized)
            # shape of residual changes but that's fine,
            # next node is already slicing it, now becomes a noop
            return gathered, new_residual

        # Variant that exposes both outputs.
        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )

        # Variant wrapped with get_first_out_wrapper.
        first_out_pattern = get_first_out_wrapper(pattern)
        first_out_replacement = get_first_out_wrapper(replacement)
        pm.register_replacement(
            first_out_pattern,
            first_out_replacement,
            self.get_inputs(),
            pm.fwd_only,
            pm_pass,
        )
class SequenceParallelismPass(VllmPatternMatcherPass):
    """
    This pass enables sequence parallelism for models.

    It looks for an AllReduce followed by an RMSNorm (optionally followed by
    quantization) and replaces it with a ReduceScatter, a local
    RMSNorm/quantization, and an AllGather:

        Input -> AllReduce -> RMSNorm -> Output
    becomes
        Input -> ReduceScatter -> RMSNorm -> AllGather -> Output

    The pass itself does not directly yield performance improvements; it lays
    the groundwork for subsequent fusion passes such as GEMM + ReduceScatter
    and AllGather + GEMM, which can significantly reduce communication
    overhead and improve overall model performance.

    The pass splits the residual tensor across TP ranks and hence divides its
    size. Because the pattern matcher starts at the end of the graph, the
    replacement contains a slice that temporarily conforms the input residual
    to the correct size. After all patterns have been matched, a
    NoOpEliminationPass cleans up what have now become no-op slices.

    An older version of the pass did not need this because it operated only on
    the custom rms_norm and fused_rms_norm_add ops, which did not complain
    about mismatched shapes during replacement. This approach therefore shares
    the assumption that correctness is only maintained if all rms_norm
    operations are split across ranks.

    Correctness-wise, this approach is strictly better than before: before,
    the graph was incorrect both semantically and shape-wise during the pass;
    now it is only semantically incorrect during the pass. Both approaches
    restore a correct graph once all patterns are matched.
    """

    @enable_fake_mode
    def __init__(self, config: VllmConfig) -> None:
        super().__init__(config)

        # min_token_num is read from the pass config (calculated during
        # config init) and stays None when there is no model config.
        self.min_token_num = None
        if config.model_config is not None:
            self.min_token_num = config.compilation_config.pass_config.sp_min_token_num

            if self.min_token_num is not None:
                # Take the min to avoid exceeding max_num_batched_tokens
                token_cap = config.scheduler_config.max_num_batched_tokens
                if token_cap is not None:
                    self.min_token_num = min(self.min_token_num, token_cap)
                logger.debug_once(
                    f"Sequence parallelism min token threshold: {self.min_token_num}",
                    scope="global",
                )

        # Used to clean up redundant views created temporarily
        # to circumvent residual shape change issues
        self.noop_cleanup = NoOpEliminationPass(config)
        self.noop_cleanup.pass_name = f"{self.pass_name}.{self.noop_cleanup.pass_name}"

        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="sequence_parallelism_pass"
        )

        # Registration order is kept: FP8 variants first, then plain RMSNorm.
        pattern_classes = (
            # RMSNorm + Static FP8 quantization patterns
            FirstAllReduceRMSNormStaticFP8Pattern,
            MiddleAllReduceRMSNormStaticFP8Pattern,
            # Normal RMSNorm patterns
            FirstAllReduceRMSNormPattern,
            MiddleAllReduceRMSNormPattern,
        )
        for epsilon in [1e-5, 1e-6]:
            for pattern_cls in pattern_classes:
                pattern_cls(epsilon, self.model_dtype, self.device).register(
                    self.patterns
                )

        self.dump_patterns(config, self.patterns)

    def is_applicable_for_range(self, compile_range: Range) -> bool:
        """
        Decide whether sequence parallelism should be applied for
        ``compile_range``.

        SP is only beneficial for larger batch sizes where the communication
        overhead is amortized. For small batches, the overhead of splitting
        and gathering tensors across TP ranks outweighs the benefits.

        Returns False (SP disabled) when:
        - Using piecewise compilation with non-concrete or TP-indivisible sizes
        - min_token_num is None (SP disabled for this device/config)
        - The compile range starts below the minimum token threshold
        """
        compilation_config = self.compilation_config

        # For piecewise compilation (not using inductor graph partition),
        # we need concrete sizes that are divisible by TP for correct splitting
        piecewise = (
            not compilation_config.use_inductor_graph_partition
            and compilation_config.splitting_ops
        )
        if piecewise:
            tp_size = get_tensor_model_parallel_world_size()
            splittable = (
                compile_range.is_single_size() and compile_range.end % tp_size == 0
            )
            if not splittable:
                return False

        # min_token_num is None when SP is disabled for this device/config
        # (e.g., non-CUDA platform, unsupported GPU, or small hidden_size).
        # Otherwise only apply SP when the range meets the minimum threshold.
        return (
            self.min_token_num is not None
            and compile_range.start >= self.min_token_num
        )

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph) -> None:
        """Apply all registered patterns to ``graph``, then remove no-op nodes."""
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)

        # Clean up reshape nodes
        self.noop_cleanup(graph)

View File

@@ -0,0 +1,77 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import operator
from collections.abc import Iterable, Iterator
from torch import fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._ops import OpOverload, OpOverloadPacket
from torch.fx.node import Target
def is_func(node: fx.Node, target: Target) -> bool:
    """Return True if ``node`` is a ``call_function`` node whose target is ``target``."""
    if node.op != "call_function":
        return False
    return bool(node.target == target)
def is_auto_func(node: fx.Node, op: OpOverload) -> bool:
    """
    Return whether ``node`` is an ``auto_functionalized`` call whose first
    positional argument equals ``op``.
    """
    if not is_func(node, auto_functionalized):
        return False
    return node.args[0] == op
def find_auto_fn_maybe(nodes: Iterable[fx.Node], op: OpOverload) -> fx.Node | None:
    """
    Return the first ``auto_functionalized`` node in ``nodes`` whose first
    positional argument equals ``op``, or None if there is no such node.
    """
    matching = (
        candidate
        for candidate in nodes
        if is_func(candidate, auto_functionalized) and candidate.args[0] == op
    )
    return next(matching, None)
def find_auto_fn(nodes: Iterable[fx.Node], op: OpOverload) -> fx.Node:
    """
    Return the first ``auto_functionalized`` node in ``nodes`` that wraps ``op``.

    Unlike ``find_auto_fn_maybe``, a missing node trips an assertion.
    """
    found = find_auto_fn_maybe(nodes, op)
    assert found is not None, f"Could not find {op} in nodes {nodes}"
    return found
def find_getitem_maybe(node: fx.Node, idx: int) -> fx.Node | None:
    """
    Return the user of ``node`` that is ``operator.getitem(node, idx)``,
    or None if no user extracts element ``idx``.
    """
    extractors = (
        consumer
        for consumer in node.users
        if is_func(consumer, operator.getitem) and consumer.args[1] == idx
    )
    return next(extractors, None)
def find_getitem(node: fx.Node, idx: int) -> fx.Node:
    """
    Return the getitem user of ``node`` that extracts element ``idx``.

    Unlike ``find_getitem_maybe``, a missing getitem trips an assertion.
    """
    getitem_node = find_getitem_maybe(node, idx)
    assert getitem_node is not None, f"Could not find getitem {idx} in node {node}"
    return getitem_node
# An auto-functionalization-aware utility for finding nodes with a specific op
# Also handles op overload packets and finds all overloads
def find_op_nodes(
op: OpOverload | OpOverloadPacket, graph: fx.Graph
) -> Iterator[fx.Node]:
if isinstance(op, OpOverloadPacket):
for overload in op.overloads():
overload_op = getattr(op, overload)
yield from find_op_nodes(overload_op, graph)
return
assert isinstance(op, OpOverload)
yield from graph.find_nodes(op="call_function", target=op)
for n in graph.find_nodes(op="call_function", target=auto_functionalized):
if n.args[0] == op:
yield n
def get_only_user(node: fx.Node) -> fx.Node:
    """
    Assert that ``node`` has exactly one user and return that user.

    Even if a node has only one user, it might share storage with another
    node, which might need to be taken into account.
    """
    consumers = list(node.users)
    assert len(consumers) == 1
    return consumers[0]

View File

@@ -0,0 +1,134 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import functools
import hashlib
import inspect
import json
import types
from collections.abc import Callable, Generator
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar
import torch
from torch import fx
from torch._subclasses.fake_tensor import FakeTensorMode, unset_fake_temporarily
if TYPE_CHECKING:
from vllm.config.utils import Range
from torch._inductor.custom_graph_pass import CustomGraphPass
# Module-global slot for the active PassContext. It starts out as None and is
# swapped in and restored by the pass_context() context manager.
_pass_context = None

# Generic parameter-spec / return-type variables used to annotate decorators.
P = ParamSpec("P")
R = TypeVar("R")
class PassContext:
    """Holds the compile range that the currently running passes operate on."""

    def __init__(self, compile_range: Range):
        # Stored as-is; read back via ``get_pass_context().compile_range``.
        self.compile_range: Range = compile_range
def get_pass_context() -> PassContext:
    """Return the active pass context; asserts that one has been set."""
    active = _pass_context
    assert active is not None
    return active
@contextmanager
def pass_context(compile_range: Range) -> Generator[None, None, None]:
    """
    Install a ``PassContext`` built from ``compile_range`` as the current pass
    context for the duration of the ``with`` block.

    Whatever context was active before is put back on exit, including when
    the body raises.
    """
    global _pass_context
    outer_context = _pass_context
    _pass_context = PassContext(compile_range)
    try:
        yield
    finally:
        _pass_context = outer_context
class InductorPass(CustomGraphPass): # type: ignore[misc]
"""
A custom graph pass that uses a hash of its source as the UUID.
This is defined as a convenience and should work in most cases.
"""
def uuid(self) -> str:
"""
Provide a unique identifier for the pass, used in Inductor code cache.
This should depend on the pass implementation, so that changes to the
pass result in recompilation.
By default, the object source is hashed.
"""
return InductorPass.hash_source(self)
@staticmethod
def hash_source(*srcs: str | Any) -> str:
"""
Utility method to hash the sources of functions or objects.
:param srcs: strings or objects to add to the hash.
Objects and functions have their source inspected.
:return:
"""
hasher = hashlib.sha256()
for src in srcs:
if isinstance(src, str):
src_str = src
elif isinstance(src, (types.FunctionType, type)):
src_str = inspect.getsource(src)
else:
# object instance
src_str = inspect.getsource(src.__class__)
hasher.update(src_str.encode("utf-8"))
return hasher.hexdigest()
@staticmethod
def hash_dict(dict_: dict[Any, Any]) -> str:
"""
Utility method to hash a dictionary, can alternatively be used for uuid.
:return: A sha256 hash of the json rep of the dictionary.
"""
encoded = json.dumps(dict_, sort_keys=True).encode("utf-8")
return hashlib.sha256(encoded).hexdigest()
def is_applicable_for_range(self, compile_range: Range) -> bool:
return True
class CallableInductorPass(InductorPass):
    """
    Wraps a plain callable as an Inductor pass and supplies a UUID for it
    automatically.
    """

    def __init__(
        self, callable: Callable[[fx.Graph], None], uuid: Any | None = None
    ) -> None:
        self.callable = callable
        # Without an explicit uuid, fall back to hashing the callable's source.
        if uuid is None:
            uuid = self.hash_source(callable)
        self._uuid = uuid

    def __call__(self, graph: torch.fx.Graph) -> None:
        """Run the wrapped callable on ``graph``."""
        self.callable(graph)

    def uuid(self) -> Any:
        """Return the uuid chosen at construction time."""
        return self._uuid
def enable_fake_mode(fn: Callable[P, R]) -> Callable[P, R]:
"""
Applies a FakeTensorMode context. This is useful when you don't want to
create or run things with real tensors.
"""
@functools.wraps(fn)
def fn_new(*args: P.args, **kwargs: P.kwargs) -> R:
with torch._guards.tracing(None), unset_fake_temporarily(), FakeTensorMode():
result = fn(*args, **kwargs)
return result
return fn_new

View File

@@ -0,0 +1,178 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from collections.abc import Callable
from typing import Any, ParamSpec, TypeVar
from torch import fx as fx
from vllm import envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.compilation.passes.utility.post_cleanup import PostCleanupPass
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.system_utils import set_env_var
from .vllm_inductor_pass import VllmInductorPass
if rocm_aiter_ops.is_enabled():
from .fusion.rocm_aiter_fusion import (
RocmAiterRMSNormQuantFusionPass,
RocmAiterSiluMulFp8GroupQuantFusionPass,
RocmAiterTritonAddRMSNormPadFusionPass,
)
if current_platform.is_cuda_alike():
from .fusion.act_quant_fusion import ActivationQuantFusionPass
from .fusion.attn_quant_fusion import AttnFusionPass
from .fusion.qk_norm_rope_fusion import QKNormRoPEFusionPass
from .fusion.rms_quant_fusion import RMSNormQuantFusionPass
from .fusion.rope_kvcache_fusion import RopeKVCacheFusionPass
from .fusion.sequence_parallelism import SequenceParallelismPass
from .utility.scatter_split_replace import ScatterSplitReplacementPass
from .utility.split_coalescing import SplitCoalescingPass
if current_platform.is_cuda():
from .fusion.allreduce_rms_fusion import AllReduceFusionPass
from .fusion.collective_fusion import AsyncTPPass
from .inductor_pass import (
CustomGraphPass,
InductorPass,
get_pass_context,
)
from .utility.fix_functionalization import FixFunctionalizationPass
from .utility.noop_elimination import NoOpEliminationPass
# Module-level logger.
logger = init_logger(__name__)

# Generic parameter-spec / return-type variables used to annotate decorators.
P = ParamSpec("P")
R = TypeVar("R")
def with_pattern_match_debug(fn: Callable[P, R]) -> Callable[P, R]:
    """
    Function decorator that turns on inductor pattern match debug
    for the duration of the call.

    The value of ``envs.VLLM_PATTERN_MATCH_DEBUG`` (when set) is exported as
    ``TORCHINDUCTOR_PATTERN_MATCH_DEBUG`` only while ``fn`` runs.
    Used to avoid logging builtin Inductor pattern matching.
    """

    @functools.wraps(fn)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        debug_val = envs.VLLM_PATTERN_MATCH_DEBUG
        if debug_val is None:
            return fn(*args, **kwargs)

        # optionally check rank here
        with set_env_var("TORCHINDUCTOR_PATTERN_MATCH_DEBUG", debug_val):
            return fn(*args, **kwargs)

    return wrapper
class PostGradPassManager(CustomGraphPass):  # type: ignore[misc]
    """
    The pass manager for post-grad passes.
    It handles configuration, adding custom passes, and running passes.
    It supports uuid for the Inductor code cache.

    When called, the manager runs, in order:
    1. every pass in ``self.passes`` that is applicable to the current
       compile range (passes appended by ``configure()`` and ``add()``,
       in insertion order)
    2. post_cleanup
    3. fix_functionalization

    This way, all passes operate on a functionalized graph.
    """

    def __init__(self) -> None:
        # Filled by configure() and add(); executed in list order.
        self.passes: list[InductorPass] = []

    @with_pattern_match_debug
    def __call__(self, graph: fx.Graph) -> None:
        """
        Run all applicable passes on ``graph``, then post-cleanup and
        fix-functionalization.

        ``VllmInductorPass.dump_prefix`` is class-level state; it is reset to
        None in a ``finally`` block so a failing pass cannot leave a stale
        index behind.
        """
        VllmInductorPass.dump_prefix = 0  # reset dump index
        try:
            compile_range = get_pass_context().compile_range
            for pass_ in self.passes:
                if pass_.is_applicable_for_range(compile_range):
                    pass_(graph)
                    VllmInductorPass.dump_prefix += 1
                else:
                    logger.debug(
                        "Skipping %s with compile range %s", pass_, compile_range
                    )

            # post-cleanup goes before fix_functionalization
            # because it requires a functional graph
            self.post_cleanup(graph)
            VllmInductorPass.dump_prefix += 1

            # always run fix_functionalization last
            self.fix_functionalization(graph)
        finally:
            VllmInductorPass.dump_prefix = None  # Cleanup index

    def configure(self, config: VllmConfig) -> None:
        """
        Instantiate the passes enabled in ``config.compilation_config.pass_config``
        and append them to ``self.passes``; also create the post-cleanup and
        fix-functionalization passes.
        """
        self.pass_config = config.compilation_config.pass_config

        # Set the current vllm config to allow tracing CustomOp instances
        with set_current_vllm_config(config, check_compile=False):
            if self.pass_config.eliminate_noops:
                self.passes += [NoOpEliminationPass(config)]

            if self.pass_config.enable_sp:
                self.passes += [SequenceParallelismPass(config)]
                if self.pass_config.fuse_gemm_comms:
                    self.passes += [AsyncTPPass(config)]

            if self.pass_config.fuse_allreduce_rms:
                self.passes += [AllReduceFusionPass(config)]

            if self.pass_config.fuse_norm_quant:
                self.passes += [RMSNormQuantFusionPass(config)]
                if rocm_aiter_ops.is_enabled():
                    self.passes += [
                        RocmAiterRMSNormQuantFusionPass(config),
                    ]
            if self.pass_config.fuse_act_quant:
                self.passes += [ActivationQuantFusionPass(config)]
                if rocm_aiter_ops.is_enabled():
                    self.passes += [RocmAiterSiluMulFp8GroupQuantFusionPass(config)]

            if self.pass_config.fuse_act_padding and rocm_aiter_ops.is_enabled():
                self.passes += [RocmAiterTritonAddRMSNormPadFusionPass(config)]

            if self.pass_config.fuse_rope_kvcache:
                self.passes += [SplitCoalescingPass(config)]
                self.passes += [ScatterSplitReplacementPass(config)]
                self.passes += [RopeKVCacheFusionPass(config)]

            if self.pass_config.fuse_attn_quant:
                self.passes += [AttnFusionPass(config)]

            if self.pass_config.enable_qk_norm_rope_fusion:
                self.passes += [SplitCoalescingPass(config)]
                self.passes += [QKNormRoPEFusionPass(config)]

            # needs a functional graph
            self.post_cleanup = PostCleanupPass(config)
            self.fix_functionalization = FixFunctionalizationPass(config)

    def add(self, pass_: InductorPass) -> None:
        """Append an additional pass; it runs after the ones already registered."""
        assert isinstance(pass_, InductorPass)
        self.passes.append(pass_)

    def uuid(self) -> str:
        """
        The PostGradPassManager is set as a custom pass in the Inductor and
        affects compilation caching. Its uuid depends on the UUIDs of all
        dependent passes and the pass config. See InductorPass for more info.
        """
        state: dict[str, Any] = {"pass_config": self.pass_config.compute_hash()}
        passes = [pass_.uuid() for pass_ in self.passes]
        passes.append(self.fix_functionalization.uuid())

        # Include the compile range in the uuid to ensure that inductor
        # recompiles the graph for the new dynamic compile range.
        state["compile_range"] = str(get_pass_context().compile_range)
        state["passes"] = passes
        return InductorPass.hash_dict(state)

View File

@@ -0,0 +1,301 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import operator
from collections.abc import Iterable
import torch
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from vllm.logger import init_logger
from vllm.platforms import current_platform
from ..fx_utils import is_func
from ..vllm_inductor_pass import VllmInductorPass
# Module-level logger.
logger = init_logger(__name__)
class FixFunctionalizationPass(VllmInductorPass):
    """
    This pass defunctionalizes certain nodes to avoid redundant tensor copies.
    After this pass, DCE (dead-code elimination) should never be run,
    as de-functionalized nodes may appear as dead code.

    To add new nodes to defunctionalize, add to the if-elif chain in __call__.
    """

    @VllmInductorPass.time_and_log
    def __call__(self, graph: torch.fx.Graph) -> None:
        """
        Replace every supported ``auto_functionalized`` node in ``graph`` with a
        direct call to the wrapped op, then erase the nodes made redundant.

        Nodes are only staged for removal while iterating and are erased in
        one go at the end.
        """
        # XPU does not support auto-functionalization yet.
        # Will enable this when switch to vllm-xpu-kernels.
        if current_platform.is_xpu():
            logger.debug(
                "XPU platform does not support fix functionalizationpass currently."
            )
            return

        # Nodes staged for removal (see _remove); erased after the loop.
        self.nodes_to_remove: list[torch.fx.Node] = []
        # Number of auto_functionalized nodes that were de-functionalized.
        count = 0
        for node in graph.nodes:
            if not is_func(node, auto_functionalized):
                continue  # Avoid deep if-elif nesting

            kwargs = node.kwargs
            # The op wrapped by auto_functionalized is its first positional arg.
            at_target = node.args[0]

            if at_target == torch.ops._C.rotary_embedding.default:
                query = kwargs["query"]
                key = kwargs["key"]
                getitem_nodes = self.getitem_users(node)

                # Special case: query and key are getitems of the same
                # split_with_sizes node, and every user of the functionalized
                # outputs is a slice_scatter.
                if (
                    is_func(query, operator.getitem)
                    and is_func(key, operator.getitem)
                    and query.args[0] == key.args[0]
                    and is_func(query.args[0], torch.ops.aten.split_with_sizes.default)
                    and all(
                        is_func(user, torch.ops.aten.slice_scatter.default)
                        for getitem_node in getitem_nodes.values()
                        for user in getitem_node.users
                    )
                ):
                    # Pattern where query and key are slices of an mm_node.
                    # While functionalized, results at [1] and [2] are scattered
                    # back into mm_node. So after de-functionalization, we can
                    # just use mm_node directly.

                    # Input of the split_with_sizes node.
                    mm_node = query.args[0].args[0]
                    for user in getitem_nodes.values():
                        for user_of_getitem in user.users:
                            if is_func(
                                user_of_getitem, torch.ops.aten.slice_scatter.default
                            ):
                                user_of_getitem.replace_all_uses_with(mm_node)
                                self._remove(user_of_getitem)
                        self._remove(user)

                    self.insert_defunctionalized(graph, node)
                    self._remove(node)

                else:
                    # Directly replace the auto_functionalize(rotary_embedding)
                    # with the inplace rotary_embedding. In theory, we shouldn't
                    # do this blindly, but in practice in vLLM it's ok. The best
                    # solution is to use auto_functionalization_v2 and then use
                    # inductor's builtin defunctionalization (reinplacing) pass.
                    mutated_args = {1: "query", 2: "key"}
                    self.defunctionalize(graph, node, mutated_args)

            # In the branches below, mutated_args maps a getitem index on the
            # functionalized node to the name of the kwarg that replaces it.

            # rms_norm replacements avoid the most copies for LLaMa.
            elif at_target == torch.ops._C.fused_add_rms_norm.default:
                mutated_args = {1: "input", 2: "residual"}
                self.defunctionalize(graph, node, mutated_args)
            elif at_target == torch.ops._C.fused_add_rms_norm_static_fp8_quant.default:  # noqa: E501
                mutated_args = {1: "result", 2: "residual"}
                self.defunctionalize(graph, node, mutated_args)
            elif at_target == torch.ops._C.rms_norm_dynamic_per_token_quant.default:  # noqa: E501
                mutated_args = {1: "result", 2: "scale", 3: "residual"}
                self.defunctionalize(graph, node, mutated_args)
            elif at_target in [
                torch.ops._C.rms_norm.default,
                torch.ops._C.rms_norm_static_fp8_quant.default,
            ]:
                mutated_args = {1: "result"}
                self.defunctionalize(graph, node, mutated_args)
            elif (
                # The hasattr guard skips this branch when the op is not registered.
                hasattr(torch.ops.vllm, "flashinfer_trtllm_fused_allreduce_norm")
                and at_target
                == torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default
            ):
                mutated_args = {
                    1: "allreduce_in",
                    2: "residual",
                    3: "norm_out",
                    4: "quant_out",
                    5: "scale_out",
                }
                self.defunctionalize(graph, node, mutated_args)
            # For some reason we need to specify the args for both
            # silu_and_mul and silu_and_mul_quant. The kwargs
            # pathway gets the wrong answer.
            elif at_target == torch.ops._C.silu_and_mul.default:
                mutated_args = {1: "result"}
                self.defunctionalize(
                    graph, node, mutated_args, args=("result", "input")
                )
            elif at_target == torch.ops._C.silu_and_mul_quant.default:
                mutated_args = {1: "result"}
                self.defunctionalize(
                    graph, node, mutated_args, args=("result", "input", "scale")
                )
            elif (
                hasattr(torch.ops._C, "silu_and_mul_nvfp4_quant")
                and at_target == torch.ops._C.silu_and_mul_nvfp4_quant.default
            ):
                mutated_args = {1: "result", 2: "result_block_scale"}
                self.defunctionalize(
                    graph,
                    node,
                    mutated_args,
                    args=(
                        "result",
                        "result_block_scale",
                        "input",
                        "input_global_scale",
                    ),
                )
            # Defunctionalize fused_qk_norm_rope to remove higher-order wrapper.
            elif at_target == torch.ops._C.fused_qk_norm_rope.default:
                mutated_args = {1: "qkv"}
                args = (
                    "qkv",
                    "num_heads_q",
                    "num_heads_k",
                    "num_heads_v",
                    "head_dim",
                    "eps",
                    "q_weight",
                    "k_weight",
                    "cos_sin_cache",
                    "is_neox",
                    "position_ids",
                )
                self.defunctionalize(graph, node, mutated_args=mutated_args, args=args)
            elif (
                hasattr(torch.ops.vllm, "fused_rope_and_unified_kv_cache_update")
                and at_target
                == torch.ops.vllm.fused_rope_and_unified_kv_cache_update.default
            ):
                mutated_args = {
                    1: "query",
                    2: "key",
                }
                self.defunctionalize(graph, node, mutated_args=mutated_args)
            # only used for test_functionalization::TestFunctionWithMutatedArgsAndReturn
            elif (
                hasattr(torch.ops.vllm, "function_with_mutated_args_and_return")
                and at_target
                == torch.ops.vllm.function_with_mutated_args_and_return.default
            ):
                mutated_args = {1: "x"}
                self.defunctionalize(graph, node, mutated_args=mutated_args)
            else:
                continue  # skip the count

            count += 1

        self.dump_graph(graph, "before_cleanup")

        # Remove the nodes all at once
        count_removed = len(self.nodes_to_remove)
        for node in self.nodes_to_remove:
            graph.erase_node(node)

        logger.debug(
            "De-functionalized %s nodes, removed %s nodes", count, count_removed
        )
        self.nodes_to_remove.clear()

    def _remove(self, node_or_nodes: torch.fx.Node | Iterable[torch.fx.Node]) -> None:
        """
        Stage a node (or nodes) for removal at the end of the pass.

        :param node_or_nodes: a single node, or an iterable of nodes; they are
            appended to ``self.nodes_to_remove`` and erased later in __call__.
        """
        if isinstance(node_or_nodes, torch.fx.Node):
            self.nodes_to_remove.append(node_or_nodes)
        else:
            self.nodes_to_remove.extend(node_or_nodes)

    def defunctionalize(
        self,
        graph: torch.fx.Graph,
        node: torch.fx.Node,
        mutated_args: dict[int, torch.fx.Node | str],
        args: tuple[torch.fx.Node | str, ...] | None = None,
    ) -> None:
        """
        De-functionalize a node by replacing it with a call to the original.
        It also replaces the getitem users with the mutated arguments.
        See replace_users_with_mutated_args and insert_defunctionalized.

        :param graph: Graph that contains ``node``
        :param node: The auto-functionalized node to replace
        :param mutated_args: The mutated arguments, indexed by getitem index
        :param args: Optional positional args for the new call, forwarded to
            insert_defunctionalized
        """
        self.replace_users_with_mutated_args(node, mutated_args)
        self.insert_defunctionalized(graph, node, args=args)
        self._remove(node)

    def replace_users_with_mutated_args(
        self, node: torch.fx.Node, mutated_args: dict[int, torch.fx.Node | str]
    ) -> None:
        """
        Replace mutated getitem users of the auto-functionalized node with the
        mutated arguments.

        :param node: The auto-functionalized node
        :param mutated_args: The mutated arguments, indexed by getitem index.
            If the value of an arg is a string, `node.kwargs[arg]` is used.
        """
        for idx, user in self.getitem_users(node).items():
            # Some functionalized nodes may return both a result at getitem[0]
            # as well as mutated args at getitem[1:...]
            if idx == 0:
                assert idx not in mutated_args, (
                    f"result at getitem[0] should not be in mutated_args for {node}"
                )
                # getitem[0] is handled by insert_defunctionalized instead.
                continue
            arg = mutated_args[idx]
            # String entries name a kwarg of the functionalized node.
            arg = node.kwargs[arg] if isinstance(arg, str) else arg
            user.replace_all_uses_with(arg)
            self._remove(user)

    def getitem_users(self, node: torch.fx.Node) -> dict[int, torch.fx.Node]:
        """
        Returns the operator.getitem users of the auto-functionalized node,
        indexed by the index they are getting.

        If several getitem users read the same index, the last one visited
        is the one kept in the returned dict.
        """
        users = {}
        for user in node.users:
            if is_func(user, operator.getitem):
                idx = user.args[1]
                users[idx] = user
        return users

    def insert_defunctionalized(
        self,
        graph: torch.fx.Graph,
        node: torch.fx.Node,
        args: tuple[torch.fx.Node | str, ...] | None = None,
    ) -> None:
        """
        Insert a new defunctionalized node into the graph before node.
        If one of the kwargs is 'out', provide args directly,
        as node.kwargs cannot be used.
        See https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351

        :param graph: Graph to insert the defunctionalized node into
        :param node: The auto-functionalized node to defunctionalize
        :param args: If we cannot use kwargs, specify args directly.
            If an arg is a string, `node.kwargs[arg]` is used.
        """  # noqa: E501
        assert is_func(node, auto_functionalized), (
            f"node must be auto-functionalized, is {node} instead"
        )

        # Create a new call to the original function
        with graph.inserting_before(node):
            function = node.args[0]
            if args is None:
                fn_node = graph.call_function(function, kwargs=node.kwargs)
            else:
                # Args passed as strings refer to items in node.kwargs
                args = tuple(
                    node.kwargs[arg] if isinstance(arg, str) else arg for arg in args
                )
                fn_node = graph.call_function(function, args=args)

        # If the function returns a value as well as mutating args inplace,
        # the functionalized node will have a getitem[0] user that holds this value
        # Replace getitem[0] user of the auto-functionalized node
        # with the new defunctionalized node directly if it exists
        users = self.getitem_users(node)
        if 0 in users:
            user = users[0]
            user.replace_all_uses_with(fn_node)
            self._remove(user)

View File

@@ -0,0 +1,130 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
import torch.fx
from torch import SymInt
from torch.fx.experimental.symbolic_shapes import statically_known_true
from vllm.logger import init_logger
from ..fx_utils import is_func
from ..vllm_inductor_pass import VllmInductorPass
# Module-level logger.
logger = init_logger(__name__)
class NoOpEliminationPass(VllmInductorPass):
    """
    This is an inductor pass that removes redundant reshape/slice operations.
    It is required for RMSNorm-quant fusion to work properly.
    That's because apply_fp8_linear adds a reshape, which is redundant
    in the 2D-case. Additionally, torch internal no-op elimination pass does
    not handle certain slice variants.

    Cases handled:
      1. A chain of reshapes is equivalent to the last reshape called on the
      base tensor (input of the first reshape).
      2. A reshape that produces the shape of the input is redundant
      3. A slice that produces the shape of the input is redundant
      4. A slice_scatter whose source has the shape of the base is redundant

    Example graph 1:
    mul_1: "f16[s0, 4096]" = ...
    view_1: "f16[s0, 128, 32]" = torch.reshape(mul_1, [-1, 128, 32])
    view_2: "f16[s0, 4096]" = torch.reshape(view_1, [-1, 4096])
    view_3: "f16[s0, 128, 32]" = torch.reshape(view_2, [-1, 128, 32])

    Can be replaced with:
    mul_1: "f16[s0, 4096]" = ...
    view_3: "f16[s0, 128, 32]" = torch.reshape(mul_1, [-1, 128, 32])

    Example graph 2:
    getitem_1: "f16[s0, 4096]" = ...
    view_1: "f16[s0, 4096]" = torch.reshape(getitem_1, [-1, 4096])
    at = auto_functionalized(static_scaled_fp8_quant, input = view_1, ...)
    out: "f8e4m3fn[s0, 4096]" = at[1]

    Can be replaced with:
    getitem_1: "f16[s0, 4096]" = ...
    at = auto_functionalized(static_scaled_fp8_quant, input = getitem_1, ...)
    out: "f8e4m3fn[s0, 4096]" = at[1]

    Example graph 3:
    arg0: "s0" = SymInt(s0)
    scaled_mm: "f16[s0, 4096]" = ...
    slice_1: "f16[s0, 4096]" = torch.slice(scaled_mm, -1, 0, arg0)
    at = auto_functionalized(fused_add_rms_norm, input = slice_1, ...)
    out: "f16[s0, 4096]" = torch.slice_scatter(scaled_mm, at[1], 0, 0, arg0)

    Can be replaced with:
    arg0: "s0" = SymInt(s0)
    scaled_mm: "f16[s0, 4096]" = ...
    at = auto_functionalized(fused_add_rms_norm, input = scaled_mm, ...)
    out: "f16[s0, 4096]" = at[1]
    """

    @VllmInductorPass.time_and_log
    def __call__(self, graph: torch.fx.Graph) -> None:
        """
        Remove no-op reshape, slice and slice_scatter nodes from ``graph``.

        Shapes are read from each node's ``meta["val"]``.
        """
        count = 0
        # Remove no-op reshapes/views:
        for node in graph.nodes:
            if is_func(node, torch.ops.aten.reshape.default):
                # Case 1: rewrite reshape chains to reshapes on the base tensor
                source = node.args[0]
                # If the input is a reshape, rebind to that node
                if is_func(source, torch.ops.aten.reshape.default):
                    # The new input is guaranteed not to be a reshape,
                    # because we process nodes in order
                    node.update_arg(0, source.args[0])
                    if not source.users:
                        graph.erase_node(source)
                        count += 1

            # remove reshape/slice if it produces the original shape
            if is_func(node, torch.ops.aten.reshape.default) or is_func(
                node, torch.ops.aten.slice.Tensor
            ):
                source = node.args[0]
                input_shape = source.meta["val"].shape
                output_shape = node.meta["val"].shape
                if self.all_dims_equivalent(input_shape, output_shape):
                    node.replace_all_uses_with(source)
                    graph.erase_node(node)
                    count += 1
            elif is_func(node, torch.ops.aten.slice_scatter.default):
                # Only the base and source tensors are needed. dim/start/end
                # have defaults and may be absent from node.args, so they are
                # deliberately not unpacked.
                base, view = node.args[:2]
                base_shape = base.meta["val"].shape
                view_shape = view.meta["val"].shape

                if self.all_dims_equivalent(base_shape, view_shape):
                    node.replace_all_uses_with(view)
                    graph.erase_node(node)
                    count += 1

        logger.debug("Removed %s no-op reshapes and slices", count)

    # ---------------------- Shape comparison helpers ----------------------
    def dims_equivalent(self, dim: int | SymInt, i_dim: int | SymInt) -> bool:
        """
        This function checks if two dimensions are equivalent.

        :param dim: The dimension arg to reshape/slice
        :param i_dim: The corresponding dimension in the input tensor
        :return: Are the dimensions equivalent?

        The dimensions count as equivalent only when ``dim == i_dim`` is
        statically known to be true; this covers equal integers as well as
        dimensions that correspond to the same SymInt.
        """
        return statically_known_true(dim == i_dim)  # type: ignore[no-any-return]

    def all_dims_equivalent(
        self, dims: Iterable[int | SymInt], i_dims: Iterable[int | SymInt]
    ) -> bool:
        """
        Check that two shapes have the same rank and pairwise-equivalent dims.

        :param dims: First shape; any iterable, including a one-shot iterator
        :param i_dims: Second shape; any iterable, including a one-shot iterator
        :return: True if the ranks match and every pair of dims is equivalent
        """
        # Materialize once so one-shot iterators are not consumed twice.
        dims_ = list(dims)
        i_dims_ = list(i_dims)
        if len(dims_) != len(i_dims_):
            # Different ranks can't be equivalent
            return False
        # Compare the materialized lists: zipping the original iterables would
        # see exhausted iterators and report a vacuous match.
        return all(self.dims_equivalent(s, i_s) for s, i_s in zip(dims_, i_dims_))

View File

@@ -0,0 +1,21 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from torch import fx
from ..vllm_inductor_pass import VllmInductorPass
class PostCleanupPass(VllmInductorPass):
    """
    Cleanup pass intended to run after the custom passes.

    The pattern matcher does not guarantee a topologically sorted graph
    and may leave unused nodes behind, so this pass re-sorts the graph
    and then removes dead code.
    """

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph) -> None:
        # Imported at call time, not at module import time.
        from torch._inductor import pattern_matcher

        pattern_matcher.stable_topological_sort(graph)
        graph.eliminate_dead_code()

View File

@@ -0,0 +1,138 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Replace ``slice_scatter`` and ``split_with_sizes`` nodes with a single
assignment if there are no users for the inplace tensor written to by
the slice_scatter call.
The inplace rotary_embedding custom op takes in mutable query and key inputs
that are split+getitem outputs of a single qkv tensor.
When functionalized, we fetch the rotated query and key from the functionalized op
using `getitem` calls. However, we also write to the qkv tensor inplace using a
`slice_scatter`, then split the inplace tensor to get the output tensors again.
Instead, if the inplace tensor has no subsequent users, we can just replace the
`slice_scatter` and `split_with_sizes` nodes with the `getitem` calls.
This is already done in fix_functionalization::FixFunctionalizationPass, but
writing a custom pass for it before defunctionalization allows matching against the
qkv split+rotary_embedding subpattern as part of e.g. the RoPE+KVCache fusion pass.
"""
import operator
import torch
from torch import fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from vllm.logger import init_logger
from ..fx_utils import is_func
from ..vllm_inductor_pass import VllmInductorPass
logger = init_logger(__name__)
class ScatterSplitReplacementPass(VllmInductorPass):
    """Replace getitem+slice_scatter+split nodes with a single getitem when
    the inplace subtensor written to by the slice_scatter has no other users.

    Here's an example graph with q_size = 512, kv_size = 64:
    split_with_sizes_1 = torch.ops.aten.split_with_sizes.default(qkv, (512, 64, 64), -1)
    at = auto_functionalized(torch.ops._C.rotary_embedding.default(positions, q, k))
    q = operator.getitem(at, 1)
    k = operator.getitem(at, 2)
    torch.ops.aten.slice_scatter.default(qkv, q, [0, 512], -1)
    torch.ops.aten.slice_scatter.default(qkv, k, [512, 512 + 64], -1)
    split_with_sizes_2 = torch.ops.aten.split_with_sizes.default(qkv, (512, 64, 64), -1)
    q = operator.getitem(split_with_sizes_2, 0)
    k = operator.getitem(split_with_sizes_2, 1)
    v = operator.getitem(split_with_sizes_2, 2)

    After this pass, this sequence of nodes is replaced with:
    split_with_sizes_1 = torch.ops.aten.split_with_sizes.default(qkv, (512, 64, 64), -1)
    at = auto_functionalized(torch.ops._C.rotary_embedding.default(positions, q, k))
    q = operator.getitem(at, 1)
    k = operator.getitem(at, 2)
    v = operator.getitem(split_with_sizes_1, 2)
    """

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph) -> None:
        # Number of rotary-embedding sites that were rewritten.
        count = 0

        # Functionalized ops this pass looks for; the ROCm variant is only
        # added when it is registered under torch.ops.vllm.
        target_ops = [torch.ops._C.rotary_embedding.default]
        if hasattr(torch.ops.vllm, "rocm_aiter_triton_rotary_embedding"):
            target_ops.append(torch.ops.vllm.rocm_aiter_triton_rotary_embedding.default)

        for node in graph.nodes:
            if not is_func(node, auto_functionalized):
                continue

            kwargs = node.kwargs
            # The wrapped op is the first positional arg of auto_functionalized.
            at_target = node.args[0]

            if at_target in target_ops:
                query = kwargs["query"]
                key = kwargs["key"]
                # Map result index -> getitem node reading that result.
                getitem_nodes = {}
                for user in node.users:
                    if is_func(user, operator.getitem):
                        getitem_nodes[user.args[1]] = user

                # NOTE(review): all() is vacuously true when the getitem
                # nodes have no users; the loops below would then leave
                # slice_scatter_*_node unbound (or stale from an earlier
                # match). Confirm this cannot happen in practice.
                if (
                    is_func(query, operator.getitem)
                    and is_func(key, operator.getitem)
                    and query.args[0] == key.args[0]
                    and is_func(query.args[0], torch.ops.aten.split_with_sizes.default)
                    and all(
                        is_func(user, torch.ops.aten.slice_scatter.default)
                        for getitem_node in getitem_nodes.values()
                        for user in getitem_node.users
                    )
                ):
                    # Pattern where query and key are slices of a qkv tensor.
                    # While functionalized, results at [1] and [2] are scattered
                    # back into qkv, then split again to get query and key.
                    # If the inplace tensor has no other users, we can replace
                    # the slice_scatter+split nodes with the original results.

                    # Each loop keeps the last user it visits; the trailing
                    # `continue` does not filter anything.
                    # NOTE(review): indices 1 and 2 are assumed present in
                    # getitem_nodes; a missing one raises KeyError.
                    for user in getitem_nodes[1].users:
                        slice_scatter_1_node = user
                        if not is_func(
                            slice_scatter_1_node, torch.ops.aten.slice_scatter.default
                        ):
                            continue

                    for user in getitem_nodes[2].users:
                        slice_scatter_2_node = user
                        if not is_func(
                            slice_scatter_2_node, torch.ops.aten.slice_scatter.default
                        ):
                            continue

                    # NOTE(review): the last user of the second slice_scatter
                    # is taken as the re-split node without a type guard.
                    for user in slice_scatter_2_node.users:
                        split_node = user
                        if not is_func(split_node, torch.ops.aten.split_with_sizes.default):
                            continue

                    # Map split output index -> getitem node reading it.
                    # NOTE(review): indices 0, 1 and 2 are all assumed present.
                    split_getitem_users = {}
                    for user in split_node.users:
                        if is_func(user, operator.getitem):
                            split_getitem_users[user.args[1]] = user

                    # Replace query node
                    split_getitem_users[0].replace_all_uses_with(getitem_nodes[1])
                    graph.erase_node(split_getitem_users[0])

                    # Replace key node
                    split_getitem_users[1].replace_all_uses_with(getitem_nodes[2])
                    graph.erase_node(split_getitem_users[1])

                    # Redirect value node to original qkv tensor
                    split_getitem_users[2].replace_input_with(split_node, query.args[0])

                    # Erase unused nodes
                    graph.erase_node(split_node)
                    graph.erase_node(slice_scatter_2_node)
                    graph.erase_node(slice_scatter_1_node)

                    count += 1

        logger.debug("Eliminated %d slice_scatter+split nodes", count)

View File

@@ -0,0 +1,70 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Coalesce duplicate ``split_with_sizes`` nodes that operate on the same
input tensor with the same split sizes.
On certain hardware/dtype combinations (e.g. B200 + FP8) the Inductor
graph may contain multiple ``split_with_sizes`` calls on the same tensor
that CSE fails to merge. This pass detects and replaces the duplicates
so that downstream pattern-matching passes (e.g. QK-Norm+RoPE fusion)
see a single split node with all users attached.
See also:
- vLLM #33295 (original issue)
- PyTorch #174472 (upstream CSE gap)
"""
import operator
import torch
from torch import fx
from vllm.logger import init_logger
from ..fx_utils import is_func
from ..vllm_inductor_pass import VllmInductorPass
logger = init_logger(__name__)
class SplitCoalescingPass(VllmInductorPass):
    """Replace duplicate ``split_with_sizes`` nodes with a single canonical
    node when they share the same input tensor, split sizes and split dim."""

    @staticmethod
    def _split_dim(node: fx.Node) -> object:
        """Return the ``dim`` argument of a split node, defaulting to 0."""
        if len(node.args) > 2:
            return node.args[2]
        return node.kwargs.get("dim", 0)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph) -> None:
        """
        Coalesce duplicate splits in ``graph`` in place.

        Only split nodes whose users are all ``operator.getitem`` are
        considered. The first matching split seen for an input becomes the
        canonical node; later duplicates have their uses redirected to it
        and are erased.
        """
        count = 0

        # Map from input tensor node -> list of split nodes seen so far.
        split_nodes: dict[fx.Node, list[fx.Node]] = {}

        for node in graph.nodes:
            if not is_func(node, torch.ops.aten.split_with_sizes.default):
                continue
            if not all(is_func(user, operator.getitem) for user in node.users):
                continue

            arg_node, split_sizes = node.args[:2]
            split_dim = self._split_dim(node)
            seen = split_nodes.setdefault(arg_node, [])

            # Find existing node with same split_sizes along the same dim.
            # Splits along different dims produce different tensors and
            # must not be merged. Dims are compared as written, so e.g. -1
            # and the equivalent positive index are treated as different.
            canonical = next(
                (
                    n
                    for n in seen
                    if list(n.args[1]) == list(split_sizes)
                    and self._split_dim(n) == split_dim
                ),
                None,
            )
            if canonical is not None:
                node.replace_all_uses_with(canonical)
                graph.erase_node(node)
                count += 1
            else:
                seen.append(node)

        logger.debug("Coalesced %d duplicate split_with_sizes nodes", count)

View File

@@ -0,0 +1,180 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
import operator
import time
from collections.abc import Callable
from dataclasses import dataclass
from typing import ClassVar
import regex as re
import torch
from torch._dynamo.utils import lazy_format_graph_code
from torch._inductor.pattern_matcher import PatternMatcherPass, PatternPrettyPrinter
from vllm.config import VllmConfig
from vllm.logger import init_logger
from .inductor_pass import InductorPass
logger = init_logger(__name__)
@dataclass
class InductorCompilationConfig:
    """Small subset of the compilation config kept by inductor passes."""

    # Filled from ``config.compilation_config.splitting_ops`` by
    # ``VllmInductorPass.__init__``.
    splitting_ops: list[str] | None = None
    # Filled from ``config.compilation_config.use_inductor_graph_partition``
    # by ``VllmInductorPass.__init__``.
    use_inductor_graph_partition: bool = False
class VllmInductorPass(InductorPass):
    """
    Inductor pass base class with access to the vLLM PassConfig.

    Offers helpers for timing a pass, logging its duration and dumping
    the graph before and after it runs.
    """

    # Keeps track of the pass index for debug dump ordering; when None the
    # index is left out of the dump name.
    dump_prefix: ClassVar[int | None] = None

    def __init__(self, config: VllmConfig):
        compilation_config = config.compilation_config
        model_config = config.model_config
        device_config = config.device_config

        # Get only the necessary CompilationConfig for the inductor pass, since
        # full `CompilationConfig` contains pointer to model which is unsafe.
        self.compilation_config = InductorCompilationConfig(
            splitting_ops=compilation_config.splitting_ops,
            use_inductor_graph_partition=compilation_config.use_inductor_graph_partition,
        )
        self.pass_config = compilation_config.pass_config
        self.model_dtype = model_config.dtype if model_config else None
        self.device: str | None = device_config.device if device_config else None
        self.pass_name = self.__class__.__name__

    @staticmethod
    def time_and_log(
        call_fn: Callable[["VllmInductorPass", torch.fx.Graph], None],
    ) -> Callable[["VllmInductorPass", torch.fx.Graph], None]:
        """Decorate a pass ``__call__`` with timing and graph dumps."""

        @functools.wraps(call_fn)
        def timed_call(self: VllmInductorPass, graph: torch.fx.Graph) -> None:
            self.begin()
            self.dump_graph(graph, "before")
            call_fn(self, graph)
            self.dump_graph(graph, "after")
            self.end_and_log()

        return timed_call

    def dump_graph(self, graph: torch.fx.Graph, stage: str) -> None:
        """Hand the graph's owning module to ``lazy_format_graph_code``."""
        index = VllmInductorPass.dump_prefix
        index_part = f".{index}" if index is not None else ""
        dump_name = f"post_grad{index_part}.{self.pass_name}.{stage}"
        lazy_format_graph_code(dump_name, graph.owning_module)

    def begin(self) -> None:
        """Record the start time of the pass in nanoseconds."""
        self._start_time = time.perf_counter_ns()

    def end_and_log(self) -> None:
        """Record the end time and log the elapsed time in milliseconds."""
        self._end_time = time.perf_counter_ns()
        elapsed_ns = self._end_time - self._start_time
        duration_ms = float(elapsed_ns) / 1.0e6
        logger.debug("%s completed in %.1f ms", self.pass_name, duration_ms)
class VllmPatternMatcherPass(VllmInductorPass):
    """
    A VllmInductorPass built around the Inductor pattern matcher.

    Its main utility is ``dump_patterns``, which writes the registered
    pattern-matcher patterns to a file to help debugging.
    TODO(luka) move more utilities to this pass.
    """

    # The number of matched patterns in the pass.
    matched_count: int = 0

    _OP_OVERLOAD_PATTERN: ClassVar[re.Pattern] = re.compile(
        r"<OpOverload\(op='([^']*)', overload='([^']*)'\)>"
    )

    def _replace_op_overloads(self, string: str) -> str:
        """Replace <OpOverload(..., ...)> with nicer formulations"""

        def as_torch_ops(match):
            return f"torch.ops.{match.group(1)}.{match.group(2)}"

        return str(self._OP_OVERLOAD_PATTERN.sub(as_torch_ops, string))

    def dump_patterns(self, config: VllmConfig, pm_pass: PatternMatcherPass) -> None:
        """
        Write the patterns registered on ``pm_pass`` to a debug file.

        Nothing is written unless ``config.compile_debug_dump_path()``
        returns a path. The output is a best-effort rendering that looks
        like Python code; if any errors appear in the output, please add
        to this method.

        TODO(luka): use pattern object to manually produce pattern graph
        """
        debug_dump_path = config.compile_debug_dump_path()
        if not debug_dump_path:
            return

        debug_dump_path.mkdir(parents=True, exist_ok=True)

        from vllm.utils.system_utils import unique_filepath

        file_path = unique_filepath(
            lambda i: debug_dump_path / f"patterns.{self.pass_name}.{i}.py"
        )

        with file_path.open("w") as f:
            header = (
                f"# This file was produced by VllmPatternMatcherPass."
                f"dump_patterns for {self.pass_name}.\n"
                f"# It does its best to produce valid-Python-looking code but"
                f" please add to dump_patterns if there are any errors.\n\n"
                f"from torch._higher_order_ops.auto_functionalize import "
                f"auto_functionalized as auto_functionalized\n"
                f"from torch._inductor.pattern_matcher import *\n"
                f"vllm = torch.ops.vllm"
            )
            print(header, file=f)

            for node, patterns in pm_pass.patterns.items():
                # fix the operator.getitem repr
                if node[1] == operator.getitem:
                    node_repr = f"({repr(node[0])}, operator.getitem)"
                else:
                    node_repr = repr(node)
                node_repr = self._replace_op_overloads(node_repr)

                print(f"\n\n# Patterns for op: {node_repr}", file=f)
                for i, pattern in enumerate(patterns):
                    # reserve auto_functionalized ahead of time
                    pp = PatternPrettyPrinter()
                    pp.namespace.create_name("auto_functionalized", None)

                    # Assemble pattern: a def line, one assignment per
                    # memoized object, then the return of the output node.
                    out_node = pp.pretty_print(pattern.pattern)
                    lines = [f"def pattern_{i}():"]
                    for key in pp.memoized_objs_names:
                        lines.append(
                            f"{pp.memoized_objs_names[key]} = "
                            f"{pp.memoized_objs_pp[key]}"
                        )
                    lines.append(f"return {out_node}")
                    pattern_repr = "\n".join(lines).replace("\n", "\n ")

                    pattern_repr = self._replace_op_overloads(pattern_repr)
                    print(f"{pattern_repr}\n", file=f)
class PrinterInductorPass(VllmInductorPass):
    """Pass that only dumps the graph, using ``name`` as the stage label."""

    def __init__(self, name: str, config: VllmConfig) -> None:
        super().__init__(config)
        self.name = name

    def __call__(self, graph: torch.fx.Graph) -> None:
        stage = self.name
        self.dump_graph(graph, stage)

View File

@@ -0,0 +1,343 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
import io
import json
import pickle
import time
from collections.abc import Callable
from pickle import Pickler
from typing import Any
import torch._functorch.config
import torch.fx as fx
from torch._inductor.runtime.triton_heuristics import CachingAutotuner
from torch._logging._internal import trace_structured
from vllm.compilation.backends import VllmBackend
from vllm.compilation.monitor import end_monitoring_torch_compile
from vllm.config import VllmConfig
from vllm.config.utils import Range
from vllm.logger import init_logger
logger = init_logger(__name__)
@dataclasses.dataclass
class RangeEntry:
    """Bookkeeping for one compile range handled by PiecewiseBackend."""

    # The range of runtime shapes this entry is used for.
    compile_range: Range
    # Set to True by PiecewiseBackend once ``runnable`` has been assigned.
    compiled: bool = False
    # Callable invoked for shapes in this range; None until it is assigned.
    runnable: Callable[..., Any] = None  # type: ignore
class PiecewiseBackend:
def __init__(
    self,
    graph: fx.GraphModule | None,
    vllm_config: VllmConfig,
    piecewise_compile_index: int,
    total_piecewise_compiles: int,
    sym_shape_indices: list[int],
    vllm_backend: VllmBackend,
    returns_tuple: bool,
    compiled_runnables: dict[str, Callable[..., Any]] | None = None,
    submod_name: str = "",
):
    """
    The backend for piecewise compilation.

    It mainly handles the compilation of static shapes and
    dispatching based on runtime shape.

    `self.graph` is compiled once for the general shape, and then for
    the different shapes specified in `compilation_config.compile_sizes`.

    Two mutually exclusive modes are supported:

    1. Compilation (graph is set, compiled_runnables is None):
       used during initial compilation, when the FX graph is available
       and has to be compiled for each shape range.

    2. Precompilation (graph is None, compiled_runnables is set):
       used when loading from cache/AOT artifacts, where pre-compiled
       callables already exist and the original graph is not needed.

    Exactly one of graph or compiled_runnables must be provided.
    """
    has_graph = graph is not None
    has_runnables = compiled_runnables is not None
    assert has_graph != has_runnables, (
        "exactly one of graph and compiled_runnables should be set."
    )

    self.graph = graph
    self.vllm_config = vllm_config
    self.compilation_config = vllm_config.compilation_config
    self.piecewise_compile_index = piecewise_compile_index
    self.total_piecewise_compiles = total_piecewise_compiles
    self.vllm_backend = vllm_backend
    self.compiled_runnables = compiled_runnables
    self.submod_name = submod_name

    self.is_first_graph = piecewise_compile_index == 0
    self.is_last_graph = piecewise_compile_index == total_piecewise_compiles - 1
    self.is_full_graph = total_piecewise_compiles == 1
    self.is_encoder_compilation = vllm_backend.is_encoder

    self.compile_ranges = self.compilation_config.get_compile_ranges()
    if self.is_encoder_compilation:
        # For encoder compilation we use the max int32 value
        # to set the upper bound of the compile ranges
        max_int32 = 2**31 - 1
        last_compile_range = self.compile_ranges[-1]
        assert (
            last_compile_range.end
            == vllm_config.scheduler_config.max_num_batched_tokens
        )
        self.compile_ranges[-1] = Range(
            start=last_compile_range.start, end=max_int32
        )
    logger.debug_once(f"PiecewiseBackend: compile_ranges: {self.compile_ranges}")

    self.compile_sizes = self.compilation_config.compile_sizes
    logger.debug_once(f"PiecewiseBackend: compile_sizes: {self.compile_sizes}")

    self.sym_shape_indices = sym_shape_indices
    self.returns_tuple = returns_tuple

    # One entry per range (or exact size) this backend can dispatch to.
    self.range_entries: dict[Range, RangeEntry] = {}

    # to_be_compiled_ranges tracks the remaining ranges to compile,
    # and updates during the compilation process, so we need to copy it
    self.to_be_compiled_ranges: set[Range] = set(self.compile_ranges)

    # We only keep compilation management inside this class directly.
    if self.compile_sizes is not None:
        for size in self.compile_sizes:
            if isinstance(size, str):
                assert size == "cudagraph_capture_sizes"
                raise NotImplementedError(
                    "cudagraph_capture_sizes not supported in compile_sizes."
                    "This should be handled in `post_init_cudagraph_sizes`."
                )
            assert isinstance(size, int)
            exact_range = Range(start=size, end=size)
            if exact_range in self.compile_ranges:
                # Already covered by the loop over compile_ranges below.
                continue
            self.range_entries[exact_range] = RangeEntry(
                compile_range=exact_range,
            )
            self.to_be_compiled_ranges.add(exact_range)

    for compile_range in self.compile_ranges:
        self.range_entries[compile_range] = RangeEntry(
            compile_range=compile_range,
        )

    # Track whether we've logged the graph for this subgraph (only log once)
    self._graph_logged = False

    # get the on_compilation_complete callback from context...
    # PiecewiseBackend is created during the first call,
    # which is when the context is set (see compilation/decorators.py)
    from vllm.compilation.backends import _on_compilation_complete_callback

    self.on_compilation_complete = _on_compilation_complete_callback.get()
def get_compiled_graph_wrapper(
    self, compiled_graph: Callable[..., Any]
) -> Callable[..., Any]:
    """
    Wrap ``compiled_graph`` so its output matches ``self.returns_tuple``.

    When ``returns_tuple`` is falsy and the compiled graph returns a
    tuple or list, only the first element is returned; otherwise the
    output is passed through unchanged.
    """

    def compiled_graph_wrapper(*args: Any) -> Any:
        graph_output = compiled_graph(*args)
        # unpack the tuple if needed
        # TODO(rzou): the implication is that we're not
        # reading the python bytecode correctly in vLLM?
        is_sequence = isinstance(graph_output, (tuple, list))
        if is_sequence and not self.returns_tuple:
            return graph_output[0]
        return graph_output

    return compiled_graph_wrapper
def check_for_ending_compilation(self) -> None:
    """
    Finish compilation once the last graph has nothing left to compile.

    Does nothing unless this is the last piecewise graph and
    ``to_be_compiled_ranges`` is empty. Otherwise it saves the compiler
    manager cache, ends torch.compile monitoring and invokes the
    completion callback when one is set.
    """
    if not self.is_last_graph or self.to_be_compiled_ranges:
        return

    # no specific sizes to compile
    # save the hash of the inductor graph for the next run
    save_started = time.perf_counter()
    self.vllm_backend.compiler_manager.save_to_file()
    elapsed = time.perf_counter() - save_started
    if elapsed > 1:
        logger.info_once(
            "Saved compiler manager cache in %.2f seconds.",
            elapsed,
            scope="local",
        )

    end_monitoring_torch_compile(self.vllm_config)

    # Call the completion callback (e.g., to save AOT compiled function)
    if self.on_compilation_complete is not None:
        self.on_compilation_complete()
def to_bytes(self) -> dict[str, bytes]:
    """
    Serialize the compiled range entries.

    Returns a dict mapping ``str(range)`` to pickled bytes for every
    entry that is compiled and whose runnable has a ``serialize``
    method. Entries that are not compiled are skipped with a debug log.
    """

    class StandaloneCompiledArtifactsPickler(Pickler):
        def reducer_override(self, obj: object) -> Any:
            if not isinstance(obj, CachingAutotuner):
                return NotImplemented
            # Autotuners are prepared, pickled on their own and embedded
            # as a pickle.loads call in the outer pickle stream.
            obj.prepare_for_pickle()
            return pickle.loads, (pickle.dumps(obj),)

    def serialize(fn: Callable[..., Any]) -> bytes:
        assert hasattr(fn, "serialize"), "fn must have serialize method"
        with torch._functorch.config.patch("bundled_autograd_cache", True):
            entry = fn.serialize()
            buffer = io.BytesIO()
            StandaloneCompiledArtifactsPickler(buffer).dump(entry)
            payload = buffer.getvalue()
        return payload

    serialized = {}
    for range_key, entry in self.range_entries.items():
        if not entry.compiled:
            logger.debug(
                "entry with range %s not compiled, so cannot get its bytes",
                range_key,
            )
            continue
        runnable = entry.runnable
        if hasattr(runnable, "serialize"):
            serialized[str(range_key)] = serialize(runnable)
    return serialized
def _fakify_args(self, args: tuple[Any, ...]) -> list[Any]:
    """
    Return the example values stored on the graph's leading placeholders.

    The values come from ``node.meta["example_value"]`` and there must be
    exactly one per entry of ``args``.
    """
    # We need to pass fake example_inputs, otherwise torch.compile
    # will fakify the example_inputs potentially causing some non dynamic
    # dimension to be duck shaped to other existing shapes that have hints
    # matching their values.
    # This is a problem because it can lead to unintended specializations!
    # if the new wrongly dynamic dim is specialized
    # it will force specializing the whole shape
    # torch.compile probably should not accept
    # non fake tensors as example inputs!
    # See issue https://github.com/vllm-project/vllm/issues/27899
    assert self.graph is not None
    fake_example_inputs = []
    for node in self.graph.graph.nodes:
        # All place holders come first
        if node.op != "placeholder":
            break
        fake_example_inputs.append(node.meta["example_value"])
    assert len(fake_example_inputs) == len(args)
    return fake_example_inputs
def _log_compile_start(self, compile_range: Range):
"""Log compilation event for TORCH_TRACE/tlparse."""
is_cudagraph_size = (
self.compile_sizes is not None and compile_range.start in self.compile_sizes
)
subgraph_index = self.piecewise_compile_index
submod_name = self.submod_name
trace_structured(
"artifact",
metadata_fn=lambda: {
"name": "vllm_piecewise_compile_start",
"encoding": "json",
},
payload_fn=lambda: json.dumps(
{
"piecewise_index": subgraph_index,
"submod_name": submod_name,
"total_piecewise_compiles": self.total_piecewise_compiles,
"compile_range_start": compile_range.start,
"compile_range_end": compile_range.end,
"is_single_size": compile_range.is_single_size(),
"is_cudagraph_capture_size": is_cudagraph_size,
}
),
)
# Log the subgraph graph dump only once per subgraph (not per size)
# to reduce log file size. The graph code is the same for all sizes.
if not self._graph_logged:
self._graph_logged = True
assert self.graph is not None
trace_structured(
"graph_dump",
metadata_fn=lambda: {
"name": f"vllm_{submod_name}",
},
payload_fn=lambda: self.graph.print_readable(print_output=False),
)
def _maybe_compile_for_range_entry(
    self, range_entry: RangeEntry, args: tuple[Any, ...]
) -> Any:
    """
    Populate ``range_entry.runnable`` the first time the entry is used.

    In precompilation mode the runnable is looked up in
    ``compiled_runnables`` by ``str(compile_range)``; otherwise the graph
    is compiled through the backend's compiler manager. Afterwards the
    entry is marked compiled, removed from ``to_be_compiled_ranges`` and
    ``check_for_ending_compilation`` is called. Nothing is returned.
    """
    if range_entry.compiled:
        return

    compile_range = range_entry.compile_range
    if self.compiled_runnables is not None:
        precompiled = self.compiled_runnables[str(compile_range)]
        range_entry.runnable = self.get_compiled_graph_wrapper(precompiled)
    else:
        self._log_compile_start(compile_range)

        # args are real arguments
        # fakify for range, real args for concrete size.
        # For concrete size, we clear the shape env in
        # compiler_manager.compile() so no need to fakify.
        if compile_range.is_single_size():
            args_list = list(args)
        else:
            args_list = self._fakify_args(args)

        with torch._functorch.config.patch("bundled_autograd_cache", True):
            range_entry.runnable = self.vllm_backend.compiler_manager.compile(
                self.graph,
                args_list,
                self.vllm_backend.inductor_config,
                self.compilation_config,
                compile_range=compile_range,
                graph_index=self.piecewise_compile_index,
                num_graphs=self.total_piecewise_compiles,
            )

    range_entry.compiled = True
    self.to_be_compiled_ranges.remove(compile_range)

    self.check_for_ending_compilation()
def _find_range_for_shape(self, runtime_shape: int) -> RangeEntry | None:
    """
    Find the range entry to use for ``runtime_shape``.

    The concrete compile sizes are checked first; if the shape is not one
    of them, the first compile range containing the shape is used.

    Returns:
        The matching RangeEntry, or None when no compile range contains
        the shape.
    """
    # A missing compile_sizes list only disables the exact-size lookup;
    # the compile ranges are still registered in range_entries and must
    # still be searched.
    compile_sizes = self.compile_sizes
    if compile_sizes is not None and runtime_shape in compile_sizes:
        return self.range_entries[Range(start=runtime_shape, end=runtime_shape)]

    for compile_range in self.compile_ranges:
        if runtime_shape in compile_range:
            return self.range_entries[compile_range]
    return None
def __call__(self, *args: Any) -> Any:
    """
    Dispatch a call to the runnable of the matching range entry.

    The runtime shape is read from ``args`` at the first index in
    ``sym_shape_indices``; the entry is compiled on first use.
    """
    shape_index = self.sym_shape_indices[0]
    runtime_shape = args[shape_index]
    range_entry = self._find_range_for_shape(runtime_shape)

    assert range_entry is not None, (
        f"Shape: {runtime_shape} out of considered ranges: {self.compile_ranges}"
    )

    self._maybe_compile_for_range_entry(range_entry, args)
    return range_entry.runnable(*args)

321
vllm/compilation/wrapper.py Normal file
View File

@@ -0,0 +1,321 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import sys
from abc import abstractmethod
from collections.abc import Callable, Generator
from contextlib import contextmanager, nullcontext
from types import CodeType
from typing import Any, ParamSpec, TypeVar
import torch
import torch._C._dynamo.guards
import vllm.envs as envs
from vllm.config import CompilationMode, CUDAGraphMode, get_current_vllm_config
from vllm.config.compilation import DynamicShapesType
from vllm.logger import init_logger
from vllm.utils.nvtx_pytorch_hooks import layerwise_nvtx_marker_context
logger = init_logger(__name__)
R = TypeVar("R")
P = ParamSpec("P")
def _noop_add_global_state_guard(
self: torch._C._dynamo.guards.GuardManager, *args: Any, **kwargs: Any
) -> None:
"""No-op to skip the GLOBAL_STATE guard entirely"""
pass
def _noop_add_torch_function_mode_stack_guard(
self: torch._C._dynamo.guards.GuardManager, *args: Any, **kwargs: Any
) -> None:
"""No-op to skip the TORCH_FUNCTION_MODE_STACK guard entirely"""
pass
@contextmanager
def _compilation_context() -> Generator[None, None, None]:
    """Temporarily apply the compilation settings and guard patches.

    While the context is active it:
    1. Raises the dynamo cache limits. (Needed for qwen2_5_vl, see
       test_qwen2_5_vl_evs_functionality.) Generally a recompilation can
       happen whenever we use a new backend instance in torch.compile.
    2. Replaces GuardManager.add_global_state_guard with a no-op to skip
       GLOBAL_STATE guards.
    3. Replaces GuardManager.add_torch_function_mode_stack_guard with a
       no-op to skip TORCH_FUNCTION_MODE_STACK guards.

    On exit every original value is restored, including when the body
    raises.
    """
    guard_manager = torch._C._dynamo.guards.GuardManager
    dynamo_config = torch._dynamo.config

    # Save original values
    saved_global_state_guard = guard_manager.add_global_state_guard
    saved_mode_stack_guard = guard_manager.add_torch_function_mode_stack_guard
    saved_cache_size = dynamo_config.cache_size_limit
    saved_accumulated_cache = dynamo_config.accumulated_cache_size_limit

    try:
        # Set higher cache limits for compilation
        dynamo_config.cache_size_limit = 2048
        dynamo_config.accumulated_cache_size_limit = 8192

        # Patch guard manager
        guard_manager.add_global_state_guard = _noop_add_global_state_guard
        guard_manager.add_torch_function_mode_stack_guard = (
            _noop_add_torch_function_mode_stack_guard
        )
        yield
    finally:
        # Restore original values
        guard_manager.add_global_state_guard = saved_global_state_guard
        guard_manager.add_torch_function_mode_stack_guard = saved_mode_stack_guard
        dynamo_config.cache_size_limit = saved_cache_size
        dynamo_config.accumulated_cache_size_limit = saved_accumulated_cache
class TorchCompileWithNoGuardsWrapper:
"""
A wrapper class for torch.compile, it ensures that all guards are dropped
when CompilationMode is not CompilationMode.STOCK_TORCH_COMPILE.
When guards are dropped, the first time __call__ is invoked, a single
compilation is triggered. Dynamo should never be traced again after that
since we drop all guards.
"""
def check_invariants_and_forward(self, *args: Any, **kwargs: Any) -> Any:
    """Run ``_check_shape_invariants`` on the inputs, then call ``forward``."""
    assert hasattr(self, "_check_shape_invariants")
    self._check_shape_invariants(*args, **kwargs)
    output = self.forward(*args, **kwargs)
    return output
def _call_with_optional_nvtx_range(
    self, callable_fn: Callable[P, R], *args: P.args, **kwargs: P.kwargs
) -> Any:
    """
    Call ``callable_fn`` with the given arguments and return its result.

    When ``layerwise_nvtx_tracing_enabled`` is set, the call is made
    inside ``layerwise_nvtx_marker_context``.
    """
    if not self.layerwise_nvtx_tracing_enabled:
        return callable_fn(*args, **kwargs)

    # The marker context receives copies of the positional and keyword
    # arguments; the callable still receives the originals.
    positional = list(args)
    keywords = dict(kwargs)
    with layerwise_nvtx_marker_context(
        "Torch Compiled Module (input):{}".format(self.__class__.__name__),
        self,
        in_tensor=positional,
        kwargs=keywords,
    ) as ctx:
        ctx.result = callable_fn(*args, **kwargs)
    return ctx.result
def __init__(self) -> None:
    """
    Build the torch.compile callable for this wrapper.

    Reads the current vLLM config, chooses the backend and compile
    options, installs a guard filter unless the mode is
    STOCK_TORCH_COMPILE, and optionally registers the bytecode hook.
    """
    self.compiled = False

    vllm_config = get_current_vllm_config()
    self.vllm_config = vllm_config
    mode = vllm_config.compilation_config.mode
    self.layerwise_nvtx_tracing_enabled = (
        vllm_config.observability_config.enable_layerwise_nvtx_tracing
    )
    if mode is None:
        raise RuntimeError("Compilation mode cannot be NO_COMPILATION")

    backend = vllm_config.compilation_config.init_backend(vllm_config)
    options = {}

    # NOTE(review): for the "inductor" string backend, `options` is the
    # config's own inductor_compile_config object, so the guard_filter_fn
    # assignment below writes into it. Confirm that is intended.
    if isinstance(backend, str) and backend == "inductor":
        options = vllm_config.compilation_config.inductor_compile_config

    # Read by __call__ to skip the fail_on_recompile stance on the first call.
    self.first_compile = True
    self.evaluate_guards = (
        vllm_config.compilation_config.dynamic_shapes_config.evaluate_guards
    )

    ds_type = vllm_config.compilation_config.dynamic_shapes_config.type

    if mode != CompilationMode.STOCK_TORCH_COMPILE:
        # Drop all the guards.
        if self.evaluate_guards:
            assert not envs.VLLM_USE_BYTECODE_HOOK, (
                "compilation_config.dynamic_shapes_config.evaluate_guards "
                "requires VLLM_USE_BYTECODE_HOOK=0. "
            )

            # Keep only the guards whose type is SHAPE_ENV.
            options["guard_filter_fn"] = lambda x: [
                entry.guard_type == "SHAPE_ENV" for entry in x
            ]
        else:
            # Keep no guards at all.
            options["guard_filter_fn"] = lambda x: [False for _ in x]

    compiled_ptr: Any = self.forward
    # Validate that unbacked dynamic shapes require VLLM_USE_BYTECODE_HOOK=False

    if ds_type == DynamicShapesType.UNBACKED:
        # reason is that bytecode does torch._dynamo.eval_frame.
        # remove_from_cache(self.original_code_object()) to force a new
        # re-compilation. And if we use
        # compiled_ptr = self.check_invariants_and_forward
        # it will reset all entries.
        assert not envs.VLLM_USE_BYTECODE_HOOK, (
            "UNBACKED dynamic shapes requires VLLM_USE_BYTECODE_HOOK=0. "
        )
        assert not self.evaluate_guards, "UNBACKED dynamic shapes do not add guards"

        compiled_ptr = self.check_invariants_and_forward

    # AOT compile is enabled only when requested and when this torch
    # version exposes the config flag; otherwise a warning is logged.
    aot_context = nullcontext()
    if envs.VLLM_USE_AOT_COMPILE:
        if hasattr(torch._dynamo.config, "enable_aot_compile"):
            aot_context = torch._dynamo.config.patch(enable_aot_compile=True)
        else:
            msg = "torch._dynamo.config.enable_aot_compile is not "
            msg += "available. AOT compile is disabled and please "
            msg += "upgrade PyTorch version to use AOT compile."
            logger.warning(msg)

    with aot_context:
        self._compiled_callable = torch.compile(
            compiled_ptr,
            fullgraph=True,
            dynamic=False,
            backend=backend,
            options=options,
        )

    if envs.VLLM_USE_BYTECODE_HOOK and mode != CompilationMode.STOCK_TORCH_COMPILE:
        torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook)
        # Filled in by bytecode_hook once Dynamo has produced new bytecode.
        # NOTE(review): nesting under this `if` is reconstructed from a
        # flattened source; __call__ only reads it on this code path.
        self._compiled_bytecode: CodeType | None = None
def aot_compile(self, *args: Any, **kwargs: Any) -> Any:
    """
    Forward to ``aot_compile`` on the compiled callable.

    The positional and keyword arguments are passed as a single
    ``(args, kwargs)`` tuple.

    Raises:
        RuntimeError: if the compiled callable has no ``aot_compile``.
    """
    compiled = self._compiled_callable
    if hasattr(compiled, "aot_compile"):
        return compiled.aot_compile((args, kwargs))
    raise RuntimeError(
        "aot_compile is not supported by the current configuration. "
        "Please make sure torch.compile is enabled with the latest "
        f"version of PyTorch (current using torch: {torch.__version__})"
    )
def __call__(self, *args: Any, **kwargs: Any) -> Any:
    """
    Run the wrapped forward through the compiled callable.

    Without the bytecode hook, the compiled callable is invoked inside
    ``_compilation_context``; after the first call, and only when
    ``evaluate_guards`` is set, the "fail_on_recompile" stance is applied.
    With the bytecode hook, STOCK_TORCH_COMPILE mode calls the compiled
    callable directly. Otherwise the first call triggers compilation and
    later calls dispatch to the saved bytecode.
    """
    if not envs.VLLM_USE_BYTECODE_HOOK:
        if self.first_compile or not self.evaluate_guards:
            ctx = nullcontext()
        else:
            ctx = torch.compiler.set_stance("fail_on_recompile")
        self.first_compile = False
        with _compilation_context(), ctx:
            return self._call_with_optional_nvtx_range(
                self._compiled_callable, *args, **kwargs
            )

    if (
        self.vllm_config.compilation_config.mode
        == CompilationMode.STOCK_TORCH_COMPILE
    ):
        return self._compiled_callable(*args, **kwargs)

    if self._compiled_bytecode:
        with self._dispatch_to_compiled_code():
            return self._call_with_optional_nvtx_range(
                self.forward, *args, **kwargs
            )

    # Make sure a compilation is triggered by clearing dynamo
    # cache.
    torch._dynamo.eval_frame.remove_from_cache(self.original_code_object())
    return self._call_with_optional_nvtx_range(
        self._compiled_callable, *args, **kwargs
    )
@abstractmethod
def forward(self, *args: Any, **kwargs: Any) -> Any:
    """Computation to compile; subclasses must implement it.

    ``__init__`` passes this method (or ``check_invariants_and_forward``,
    which calls it) to ``torch.compile``.
    """
    ...
def original_code_object(self) -> CodeType:
    """Return the code object of the class-level ``forward`` function."""
    forward_fn = self.__class__.forward
    return forward_fn.__code__
def bytecode_hook(self, old_code: CodeType, new_code: CodeType) -> None:
    """Hook to save the compiled bytecode for direct execution.

    The hook ignores any compilation that is not for this instance's
    ``forward``: first by comparing ``old_code`` against
    ``self.original_code_object()``, then by locating the frame being
    compiled and checking that its ``self`` local is this very object.

    When a debug dump path is configured, the transformed bytecode is
    additionally decompiled to ``transformed_code.py`` on a best-effort
    basis (failures are logged, never raised).

    Args:
        old_code: Code object that was handed to the compiler.
        new_code: Transformed code object produced for ``old_code``.

    Raises:
        RuntimeError: If cudagraph mode is not ``CUDAGraphMode.NONE`` and
            the transformed bytecode references the name ``update``, which
            is treated as a sign that a module buffer is modified during
            the forward pass.
    """
    if old_code is not self.original_code_object():
        return
    # code borrowed from https://github.com/thuml/depyf/blob/f4ad79fadee27ea113b4c75202db1eb1a11c0dbc/depyf/explain/enable_debugging.py#L25
    # Walk up the call stack to the `_compile` frame of convert_frame.py;
    # its `frame` local is the Python frame that is being compiled.
    frame = sys._getframe()
    while frame and frame.f_back:
        frame = frame.f_back
        code_name = frame.f_code.co_name
        file_name = frame.f_code.co_filename.split(os.path.sep)[-1]
        if code_name == "_compile" and file_name == "convert_frame.py":
            break
    frame = frame.f_locals["frame"]
    assert frame.f_code == old_code

    # The code object is shared by every instance of the class, so only
    # record the bytecode when the compiled frame belongs to this instance.
    if frame.f_locals["self"] is not self:
        return

    self._compiled_bytecode = new_code

    path = self.vllm_config.compile_debug_dump_path()
    if path:
        decompiled_file = path / "transformed_code.py"
        if not decompiled_file.exists():
            try:
                # usually the decompilation will succeed for most models,
                # as we guarantee a full-graph compilation in Dynamo.
                # but there's no 100% guarantee, since decompilation is
                # not a reversible process.
                import depyf

                src = depyf.decompile(new_code)

                with open(decompiled_file, "w") as f:
                    f.write(src)

                logger.debug("Dynamo transformed code saved to %s", decompiled_file)
            except Exception:
                # The dump is a debugging aid only: keep it best-effort,
                # but leave a trace so a missing file can be diagnosed.
                logger.debug(
                    "Failed to save Dynamo transformed code to %s",
                    decompiled_file,
                    exc_info=True,
                )

    if (
        self.vllm_config.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
        and "update" in new_code.co_names
    ):
        import depyf

        src = depyf.decompile(new_code)
        msg = (
            "Assigning / modifying buffers of nn.Module during forward pass is not "
            "allowed when using cudagraph inside the compiler because it will "
            "cause silent errors. Please use eager mode or fix the code. The "
            "following code contains clues about which buffer is being modified "
            f"(please search for the usage of the function `update`):\n{src}"
        )
        raise RuntimeError(msg)
@contextmanager
def _dispatch_to_compiled_code(self) -> Generator[None, None, None]:
# noqa: E501
"""
Context manager to dispatch to internally compiled code for torch<2.8.
Why does this work? Because Dynamo guarantees that the compiled
bytecode has exactly the same arguments, cell variables, and free
variables as the original code. Therefore we can directly switch
the code object in the function and call it.
See https://dev-discuss.pytorch.org/t/what-is-the-relationship-requirement-among-original-bytecode-transformed-bytecode-and-bytecode-returned-by-hooks-in-dynamo/1693/7 for more details.
""" # noqa: E501 line too long
original = self.original_code_object()
assert self._compiled_bytecode is not None
self.__class__.forward.__code__ = self._compiled_bytecode
try:
yield
finally:
self.__class__.forward.__code__ = original