Add minimal vLLM 0.16.1 build repo for BI-V150
This commit is contained in:
0
vllm/compilation/__init__.py
Normal file
0
vllm/compilation/__init__.py
Normal file
1131
vllm/compilation/backends.py
Normal file
1131
vllm/compilation/backends.py
Normal file
File diff suppressed because it is too large
Load Diff
57
vllm/compilation/base_static_graph.py
Normal file
57
vllm/compilation/base_static_graph.py
Normal file
@@ -0,0 +1,57 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Protocol
|
||||
|
||||
from vllm.config import CUDAGraphMode, VllmConfig
|
||||
|
||||
|
||||
class AbstractStaticGraphWrapper(Protocol):
    """
    StaticGraphWrapper interface that allows platforms to wrap a callable
    to be captured as a static graph.
    """

    def __init__(
        self,
        runnable: Callable[..., Any],
        vllm_config: VllmConfig,
        runtime_mode: CUDAGraphMode,
        **kwargs: Any,
    ) -> None:
        """
        Initializes the StaticGraphWrapper class with graph capturing and
        execution-related configurations.

        Args:
            runnable (Callable): The callable to be wrapped and captured.
            vllm_config (VllmConfig): Global configuration for vLLM.
            runtime_mode (CUDAGraphMode): The style of the static
                graph runtime. See CUDAGraphMode in vllm/config.py.
                Note that only the subset enum `NONE`, `PIECEWISE` and `FULL`
                are used as concrete runtime mode for cudagraph dispatching.
        Keyword Args:
            kwargs: Additional keyword arguments for platform-specific
                configurations.
        """
        raise NotImplementedError

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """
        Executes the wrapped callable.

        If the current runtime mode in the ForwardContext matches the runtime
        mode of this instance, it replays the CUDAGraph or captures it using
        the callable if it hasn't been captured yet. Otherwise, it calls the
        original callable directly.

        Args:
            *args: Variable length input arguments to be passed into the
                callable.
            **kwargs: Keyword arguments to be passed into the callable.

        Returns:
            Any: Output of the executed callable.
        """
        raise NotImplementedError
|
||||
516
vllm/compilation/caching.py
Normal file
516
vllm/compilation/caching.py
Normal file
@@ -0,0 +1,516 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import hashlib
|
||||
import inspect
|
||||
import os
|
||||
import pickle
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import Any, Literal
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
from torch.utils import _pytree as pytree
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.compiler_interface import get_inductor_factors
|
||||
from vllm.config import VllmConfig, get_current_vllm_config
|
||||
from vllm.config.utils import hash_factors
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.hashing import safe_hash
|
||||
|
||||
try:
    # SerializableCallable only exists on newer PyTorch builds that ship
    # torch._dynamo.aot_compile; fall back to a plain base class so
    # VllmSerializableFunction can still be defined below.
    from torch._dynamo.aot_compile import SerializableCallable
except ImportError:
    SerializableCallable = object

# Sanity check: the fallback (or the import) must yield a class we can subclass.
assert isinstance(SerializableCallable, type)

logger = init_logger(__name__)
|
||||
|
||||
|
||||
class StandaloneCompiledArtifacts:
|
||||
"""Storage for standalone compiled artifacts with content-based deduplication.
|
||||
|
||||
Deduplication works via a two-level indirection:
|
||||
1. `submodule_bytes` maps "{submod_name}_{shape}" -> SHA256 hash
|
||||
2. `submodule_bytes_store` maps SHA256 hash -> actual bytes
|
||||
|
||||
When inserting, we compute the SHA256 hash of the bytes. If the hash
|
||||
already exists in `submodule_bytes_store`, we reuse the existing entry
|
||||
rather than storing duplicate bytes. This is common because submodules
|
||||
often compile to identical artifacts (e.g., identical transformer layers
|
||||
split on attn)
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
# dict from submodule name to byte hash
|
||||
self.submodule_bytes: dict[str, str] = {}
|
||||
# dict from byte hash to bytes
|
||||
self.submodule_bytes_store: dict[str, bytes] = {}
|
||||
# dict from byte hash to loaded module
|
||||
self.loaded_submodule_store: dict[str, Any] = {}
|
||||
|
||||
def insert(self, submod_name: str, shape: str, entry: bytes) -> None:
|
||||
hasher = hashlib.sha256()
|
||||
hasher.update(entry)
|
||||
hex_digest = hasher.hexdigest()
|
||||
self.submodule_bytes[f"{submod_name}_{shape}"] = hex_digest
|
||||
if hex_digest not in self.submodule_bytes_store:
|
||||
self.submodule_bytes_store[hex_digest] = entry
|
||||
logger.debug(
|
||||
"inserting new artifact for submod %s with shape %s "
|
||||
"(%s bytes) at hash %s",
|
||||
submod_name,
|
||||
shape,
|
||||
len(entry),
|
||||
hex_digest,
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
"reusing existing cache artifact for submod %s "
|
||||
"with shape %s (%s bytes) at hash %s",
|
||||
submod_name,
|
||||
shape,
|
||||
len(entry),
|
||||
hex_digest,
|
||||
)
|
||||
|
||||
def get(self, submod_name: str, shape: str) -> bytes:
|
||||
logger.debug(
|
||||
"getting artifact for submod %s with shape %s",
|
||||
submod_name,
|
||||
shape,
|
||||
)
|
||||
return self.submodule_bytes_store[
|
||||
self.submodule_bytes[f"{submod_name}_{shape}"]
|
||||
]
|
||||
|
||||
def get_loaded(self, submod_name: str, shape: str) -> Any:
|
||||
logger.debug(
|
||||
"getting artifact for submod %s with shape %s",
|
||||
submod_name,
|
||||
shape,
|
||||
)
|
||||
return self.loaded_submodule_store[
|
||||
self.submodule_bytes[f"{submod_name}_{shape}"]
|
||||
]
|
||||
|
||||
def size_bytes(self) -> int:
|
||||
return sum(len(entry) for entry in self.submodule_bytes_store.values())
|
||||
|
||||
def num_artifacts(self) -> int:
|
||||
return len(self.submodule_bytes_store)
|
||||
|
||||
def num_entries(self) -> int:
|
||||
return len(self.submodule_bytes)
|
||||
|
||||
def submodule_names(self) -> list[str]:
|
||||
# get unique "{submod_name}" from "{submod_name}_{shape}", preserving order
|
||||
names = [cache_key.rsplit("_", 1)[0] for cache_key in self.submodule_bytes]
|
||||
return list(dict.fromkeys(names))
|
||||
|
||||
def load_all(self) -> None:
|
||||
import concurrent.futures
|
||||
|
||||
# check already loaded
|
||||
if len(self.loaded_submodule_store) == len(self.submodule_bytes_store):
|
||||
return
|
||||
|
||||
from torch._inductor.standalone_compile import AOTCompiledArtifact
|
||||
|
||||
def _load_entry(entry_bytes: bytes) -> AOTCompiledArtifact:
|
||||
entry = pickle.loads(entry_bytes)
|
||||
return AOTCompiledArtifact.deserialize(entry)
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
entries = list(self.submodule_bytes_store.values())
|
||||
loaded_entries = list(executor.map(_load_entry, entries))
|
||||
|
||||
for i, k in enumerate(self.submodule_bytes_store.keys()):
|
||||
self.loaded_submodule_store[k] = loaded_entries[i]
|
||||
|
||||
logger.debug("loaded all %s submodules", self.num_artifacts())
|
||||
|
||||
def __getstate__(self) -> dict[str, dict[str, str] | dict[str, bytes]]:
|
||||
return {
|
||||
"submodule_bytes": self.submodule_bytes,
|
||||
"submodule_bytes_store": self.submodule_bytes_store,
|
||||
}
|
||||
|
||||
def __setstate__(self, state: dict[str, dict[str, Any]]) -> None:
|
||||
self.submodule_bytes = state["submodule_bytes"]
|
||||
self.submodule_bytes_store = state["submodule_bytes_store"]
|
||||
self.loaded_submodule_store = {}
|
||||
|
||||
|
||||
class VllmSerializableFunction(SerializableCallable):  # type: ignore[misc]
    """
    A wrapper around a compiled function by vllm. It will forward the tensor
    inputs to the compiled function and return the result.
    It also implements a serialization interface to support PyTorch's precompile
    with custom backend, so that we can save and load the compiled function on
    disk. There's no need to wrap around the compiled function if we don't want
    to serialize them in particular cases.
    Right now serialization for the custom backend is done via
    serializing the Dynamo fx graph plus example inputs.
    """

    def __init__(
        self,
        graph_module: torch.fx.GraphModule,
        example_inputs: Sequence[Any],
        prefix: str,
        optimized_call: Callable[..., Any],
        is_encoder: bool = False,
        vllm_backend: Any | None = None,
        sym_tensor_indices: list[int] | None = None,
    ) -> None:
        assert isinstance(graph_module, torch.fx.GraphModule)
        self.graph_module = graph_module
        self.example_inputs = example_inputs
        self.prefix = prefix
        self.optimized_call = optimized_call
        self.is_encoder = is_encoder
        self.shape_env = None
        self.vllm_backend = vllm_backend
        self.sym_tensor_indices = sym_tensor_indices
        # Pick up the ShapeEnv from the first symbolic input, if any.
        sym_input = next(
            (i for i in self.example_inputs if isinstance(i, torch.SymInt)), None
        )
        if sym_input is not None:
            self.shape_env = sym_input.node.shape_env

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        return self.optimized_call(*args, **kwargs)

    @classmethod
    def serialize_compile_artifacts(
        cls, compiled_fn: "VllmSerializableFunction"
    ) -> bytes:
        """Pickle `compiled_fn` (graph + meta example inputs) to bytes."""
        import sympy
        from torch._subclasses import FakeTensorMode
        from torch.fx._graph_pickler import GraphPickler, Options

        state = compiled_fn.__dict__.copy()
        # These fields are not picklable and are rebuilt on deserialization.
        state.pop("optimized_call")
        state.pop("shape_env")
        state.pop("vllm_backend", None)
        # Strip non-essential (and often unpicklable) node metadata.
        for node in state["graph_module"].graph.nodes:
            node.meta.pop("source_fn_stack", None)
            node.meta.pop("nn_module_stack", None)
        for name, submod in state["graph_module"].named_children():
            if hasattr(submod, "graph"):
                for node in submod.graph.nodes:
                    node.meta.pop("source_fn_stack", None)
                    node.meta.pop("nn_module_stack", None)

        graph_reducer_override = GraphPickler.reducer_override

        def _graph_reducer_override(
            self: GraphPickler, obj: Any
        ) -> tuple[Callable[..., Any], tuple[Any, ...]] | Any:
            # Custom sympy functions carry their own unpickler hook.
            if (
                inspect.isclass(obj)
                and issubclass(obj, sympy.Function)
                and hasattr(obj, "_torch_unpickler")
            ):
                return obj._torch_unpickler, (obj._torch_handler_name,)
            # FakeTensorMode is reconstructed fresh on load; drop it here.
            if isinstance(obj, FakeTensorMode):
                return type(None), ()
            return graph_reducer_override(self, obj)

        # Mask off tensor inputs since they are large and their data isn't
        # needed; the meta tensors keep shape/dtype for make_copy_and_call.
        # (The sym_tensor_indices and non-sym cases used to be separate
        # branches with identical bodies — merged here.)
        state["example_inputs"] = pytree.tree_map_only(
            torch.Tensor,
            lambda inp: torch.empty_like(inp, device="meta"),
            state["example_inputs"],
        )
        with patch.object(GraphPickler, "reducer_override", _graph_reducer_override):
            state["graph_module"] = GraphPickler.dumps(
                state["graph_module"], Options(ops_filter=None)
            )
            state["example_inputs"] = GraphPickler.dumps(state["example_inputs"])

        if compiled_fn.vllm_backend:
            (
                standalone_compile_artifacts,
                sym_shape_indices_map,
                returns_tuple_map,
            ) = compiled_fn.vllm_backend.collect_standalone_compile_artifacts()
            state["standalone_compile_artifacts"] = standalone_compile_artifacts
            state["sym_shape_indices_map"] = sym_shape_indices_map
            state["returns_tuple_map"] = returns_tuple_map
        return pickle.dumps(state)

    @classmethod
    def deserialize_compile_artifacts(cls, data: bytes) -> "VllmSerializableFunction":
        """Rebuild a VllmSerializableFunction from `serialize_compile_artifacts`
        output, either from mega AOT artifacts or by lazily re-running the
        VllmBackend (which should hit the compilation cache)."""
        from torch._guards import TracingContext, tracing
        from torch._subclasses import FakeTensorMode
        from torch.fx._graph_pickler import GraphPickler
        from torch.fx.experimental.symbolic_shapes import ShapeEnv

        state = pickle.loads(data)
        fake_mode = FakeTensorMode(shape_env=ShapeEnv())
        state["graph_module"] = GraphPickler.loads(state["graph_module"], fake_mode)
        state["graph_module"].recompile()
        state["example_inputs"] = GraphPickler.loads(state["example_inputs"], fake_mode)

        standalone_compile_artifacts = state.pop("standalone_compile_artifacts", None)
        sym_shape_indices_map = state.pop("sym_shape_indices_map", {})
        returns_tuple_map = state.pop("returns_tuple_map", {})

        if envs.VLLM_USE_MEGA_AOT_ARTIFACT:
            assert standalone_compile_artifacts is not None
            submod_names = standalone_compile_artifacts.submodule_names()
            num_submods = len(submod_names)
            num_artifacts = standalone_compile_artifacts.num_artifacts()

            logger.info(
                "reconstructing serializable fn from standalone compile "
                "artifacts. num_artifacts=%d num_submods=%d",
                num_artifacts,
                num_submods,
            )

            fn = reconstruct_serializable_fn_from_mega_artifact(
                state=state,
                standalone_compile_artifacts=standalone_compile_artifacts,
                vllm_config=get_current_vllm_config(),
                sym_shape_indices_map=sym_shape_indices_map,
                returns_tuple_map=returns_tuple_map,
            )

            logger.info(
                "reconstructed serializable fn from standalone compile artifacts"
            )

            return fn

        # Fall back to standard VllmBackend
        from vllm.compilation.backends import VllmBackend

        is_encoder = state.get("is_encoder", False)
        vllm_backend: VllmBackend = VllmBackend(
            get_current_vllm_config(), state["prefix"], is_encoder
        )

        def optimized_call(*example_inputs: Any) -> Any:
            """
            On the first run of the optimized call, we rerun the compiler
            backend which should result in a cache hit. After the backend
            call returns, we just do a one-time replacement of the optimized
            call with the compiled function, so that subsequent calls are on
            the AOT compiled path.
            """
            # Fill the meta placeholders (None slots) with the real inputs.
            compile_inputs = [
                inp if inp is not None else example_inputs[i]
                for i, inp in enumerate(fn.example_inputs)
            ]
            with tracing(TracingContext(fake_mode)):
                fn.optimized_call = vllm_backend(
                    state["graph_module"], compile_inputs
                ).optimized_call
            return fn.optimized_call(*example_inputs)

        fn = cls(**state, optimized_call=optimized_call)
        return fn

    @property
    def co_name(self) -> Literal["VllmSerializableFunction"]:
        """
        Used for depyf debugging.
        """
        return "VllmSerializableFunction"
|
||||
|
||||
|
||||
def reconstruct_serializable_fn_from_mega_artifact(
    state: dict[str, Any],
    standalone_compile_artifacts: "StandaloneCompiledArtifacts",
    vllm_config: VllmConfig,
    sym_shape_indices_map: dict[str, list[int]],
    returns_tuple_map: dict[str, bool],
) -> "VllmSerializableFunction":
    """Construct a VllmSerializableFunction from cached inductor artifacts.

    This function reconstructs a callable model from pre-compiled inductor
    artifacts without re-running the compilation. It:
    1. Loads all cached artifacts
    2. Builds compiled callables for each submodule/shape
    3. Creates PiecewiseBackend instances that dispatch to cached artifacts
    4. Wraps with cudagraph if needed
    5. Returns the final VllmSerializableFunction

    Note: This function shares similar logic with PiecewiseCompileInterpreter
    in backends.py. Both create PiecewiseBackend instances and wrap them with
    cudagraph. The key difference is:
    - this function: PiecewiseBackend receives pre-compiled runnables
      (compiled_runnables is set, graph is None)
    - PiecewiseCompileInterpreter: PiecewiseBackend receives the FX graph
      to compile (graph is set, compiled_runnables is None)

    If modifying the backend creation/wrapping logic, consider updating both.

    Args:
        state: Deserialized state dict containing graph_module, example_inputs,
            prefix, sym_tensor_indices, is_encoder, etc.
        standalone_compile_artifacts: The StandaloneCompiledArtifacts containing
            pre-compiled artifacts for each submodule/shape combination.
        vllm_config: The vLLM configuration.
        sym_shape_indices_map: Mapping from submod_name to sym_shape_indices.
        returns_tuple_map: Mapping from submod_name to returns_tuple.

    Returns:
        A VllmSerializableFunction that can be called directly.
    """
    from vllm.compilation.backends import (
        VllmBackend,
        make_copy_and_call,
        wrap_with_cudagraph_if_needed,
    )
    from vllm.compilation.piecewise_backend import PiecewiseBackend

    prefix = state["prefix"]
    is_encoder = state.get("is_encoder", False)
    split_gm = state["graph_module"]
    compilation_config = vllm_config.compilation_config

    standalone_compile_artifacts.load_all()

    submod_names = standalone_compile_artifacts.submodule_names()
    compiled_callables: dict[str, dict[str, Callable[..., Any]]] = {}

    # Group the loaded runnables by submodule, keyed by shape string.
    for cache_key in standalone_compile_artifacts.submodule_bytes:
        submod_name, shape_str = cache_key.rsplit("_", 1)
        compiled_callables.setdefault(submod_name, {})[shape_str] = (
            standalone_compile_artifacts.get_loaded(submod_name, shape_str)
        )

    vllm_backend = VllmBackend(vllm_config, prefix, is_encoder)
    # Nothing needs to be cached here (everything is pre-compiled), but the
    # compiler manager still requires an initialized cache directory.
    dummy_cache_dir = os.path.join(envs.VLLM_CACHE_ROOT, "dummy_cache")
    os.makedirs(dummy_cache_dir, exist_ok=True)
    vllm_backend.compiler_manager.initialize_cache(
        cache_dir=dummy_cache_dir,
        disable_cache=True,
        prefix=prefix,
    )

    # spot check that cached submodules exist in the graph structure
    graph_children = {name for name, _ in split_gm.named_children()}
    missing = set(submod_names) - graph_children
    assert not missing, (
        f"artifacts reference submodules not in graph: {missing}. "
        f"graph has: {sorted(graph_children)}"
    )

    for i, submod_name in enumerate(submod_names):
        assert submod_name in sym_shape_indices_map and submod_name in returns_tuple_map

        sym_shape_indices = sym_shape_indices_map[submod_name]
        returns_tuple = returns_tuple_map[submod_name]
        runnables = compiled_callables[submod_name]

        piecewise_backend = PiecewiseBackend(
            graph=None,  # not needed for cached artifacts
            vllm_config=vllm_config,
            piecewise_compile_index=i,
            total_piecewise_compiles=len(submod_names),
            sym_shape_indices=sym_shape_indices,
            vllm_backend=vllm_backend,
            returns_tuple=returns_tuple,
            compiled_runnables=runnables,
        )

        is_first = i == 0
        is_last = i == len(submod_names) - 1
        wrapped_backend = wrap_with_cudagraph_if_needed(
            piecewise_backend,
            vllm_config,
            compilation_config,
            is_first,
            is_last,
        )

        # Assign through __dict__ so the wrapper shadows the original
        # nn.Module child without going through Module.__setattr__.
        split_gm.__dict__[submod_name] = wrapped_backend
        logger.debug(
            "Replaced submodule %s with piecewise backend from cache",
            submod_name,
        )

    if compilation_config.cudagraph_copy_inputs:
        sym_tensor_indices = state["sym_tensor_indices"]
        # Real (device) buffers matching the meta example inputs; inputs are
        # copied into these before calling into the cudagraph.
        input_buffers = [
            torch.empty_like(
                state["example_inputs"][idx], device=vllm_config.device_config.device
            )
            for idx in sym_tensor_indices
        ]
        optimized_call = make_copy_and_call(sym_tensor_indices, input_buffers, split_gm)
    else:
        optimized_call = split_gm

    fn = VllmSerializableFunction(
        **state,
        optimized_call=optimized_call,
        vllm_backend=None,
    )
    return fn
|
||||
|
||||
|
||||
def aot_compile_hash_factors(vllm_config: VllmConfig) -> list[str]:
    """Collect the hash factors that key the AOT compilation cache.

    Returns a list of hash strings; two processes with identical factors can
    safely share compiled artifacts.
    """
    factors = []
    # 0. factors come from the env, for example, The values of
    # VLLM_PP_LAYER_PARTITION will affect the computation graph.
    env_hash = hash_factors(envs.compile_factors())
    factors.append(env_hash)

    # 1. factors come from the vllm_config (it mainly summarizes how the
    # model is created)
    config_hash = vllm_config.compute_hash()
    factors.append(config_hash)

    # 2. inductor factors if applicable
    if envs.VLLM_USE_MEGA_AOT_ARTIFACT:
        factors.extend(get_inductor_factors())

    return factors
|
||||
|
||||
|
||||
def _compute_code_hash_with_content(file_contents: dict[str, str]) -> str:
    """Hash a mapping of filepath -> file content into a single hex digest.

    Files are processed in sorted path order so the hash is deterministic
    regardless of dict insertion order.
    """
    # sorted() on items already orders by key; no need for list(sorted(...)).
    items = sorted(file_contents.items(), key=lambda x: x[0])
    hash_content = []
    for filepath, content in items:
        hash_content.append(filepath)
        if filepath == "<string>":
            # This means the function was dynamically generated, with
            # e.g. exec(). We can't actually check these.
            continue
        hash_content.append(content)
    result: str = safe_hash(
        "\n".join(hash_content).encode(), usedforsecurity=False
    ).hexdigest()
    return result
|
||||
|
||||
|
||||
def _compute_code_hash(files: set[str]) -> str:
    """Read each traced file and hash paths + contents for the compile cache.

    Non-existent paths (e.g. "<string>", frozen modules) hash as empty content
    rather than raising.
    """
    logger.debug(
        "Traced files (to be considered for compilation cache):\n%s", "\n".join(files)
    )
    file_contents = {}
    for filepath in files:
        # Skip files that don't exist (e.g., <string>, <frozen modules>, etc.)
        if not os.path.isfile(filepath):
            file_contents[filepath] = ""
        else:
            with open(filepath) as f:
                file_contents[filepath] = f.read()
    return _compute_code_hash_with_content(file_contents)
|
||||
660
vllm/compilation/compiler_interface.py
Normal file
660
vllm/compilation/compiler_interface.py
Normal file
@@ -0,0 +1,660 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import contextlib
|
||||
import copy
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from contextlib import ExitStack
|
||||
from typing import Any, Literal
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
import torch._inductor.compile_fx
|
||||
import torch.fx as fx
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.utils import Range
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.hashing import safe_hash
|
||||
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class CompilerInterface:
    """
    The interface for a compiler that can be used by vLLM.
    """

    # The name of the compiler, e.g. inductor.
    # This is a class-level attribute.
    name: str

    def initialize_cache(
        self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
    ) -> None:
        """
        when the vLLM process uses `cache_dir` as the cache directory,
        the compiler should initialize itself with the cache directory,
        e.g. by re-directing its own cache directory to a sub-directory.

        prefix can be used in combination with cache_dir to figure out the base
        cache directory, e.g. there're multiple parts of model being compiled,
        but we want to share the same cache directory for all of them.

        e.g.
        cache_dir = "/path/to/dir/backbone", prefix = "backbone"
        cache_dir = "/path/to/dir/eagle_head", prefix = "eagle_head"
        """
        pass

    def compute_hash(self, vllm_config: VllmConfig) -> str:
        """
        Gather all the relevant information from the vLLM config,
        to compute a hash so that we can cache the compiled model.

        See [`VllmConfig.compute_hash`][vllm.config.VllmConfig.compute_hash]
        to check what information
        is already considered by default. This function should only
        consider the information that is specific to the compiler.
        """
        return ""

    def compile(
        self,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        compiler_config: dict[str, Any],
        compile_range: Range,
        key: str | None = None,
    ) -> tuple[Callable[..., Any] | None, Any | None]:
        """
        Compile the graph with the given example inputs and compiler config,
        with a range. The `compile_range` specifies the range of the inputs,
        it could be concrete size (if compile_sizes is provided), e.g. [4, 4]
        or a range [5, 8].
        Right now we only support one variable in ranges for all inputs,
        which is the batchsize (number of tokens) during inference.

        Dynamo will make sure `graph(*example_inputs)` is valid.

        The function should return a compiled callable function, as well as
        a handle that can be used to directly load the compiled function.

        The handle should be a plain Python object, preferably a string or a
        file path for readability.

        If the compiler doesn't support caching, it should return None for the
        handle. If the compiler fails to compile the graph, it should return
        None for the compiled function as well.

        `key` is required for StandaloneInductorAdapter, it specifies where to
        save the compiled artifact. The compiled artifact gets saved to
        `cache_dir/key`.
        """
        return None, None

    def load(
        self,
        handle: Any,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        graph_index: int,
        compile_range: Range,
    ) -> Callable[..., Any]:
        """
        Load the compiled function from the handle.
        Raises an error if the handle is invalid.

        The handle is the second return value of the `compile` function.
        """
        raise NotImplementedError("caching is not supported")
|
||||
|
||||
|
||||
class AlwaysHitShapeEnv:
    """
    Why do we need this class:

    For normal `torch.compile` usage, every compilation will have
    one Dynamo bytecode compilation and one Inductor compilation.
    The Inductor compilation happens under the context of the
    Dynamo bytecode compilation, and that context is used to
    determine the dynamic shape information, etc.

    For our use case, we only run Dynamo bytecode compilation once,
    and run Inductor compilation multiple times with different shapes
    plus a general shape. The compilation for specific shapes happens
    outside of the context of the Dynamo bytecode compilation. At that
    time, we don't have shape environment to provide to Inductor, and
    it will fail the Inductor code cache lookup.

    By providing a dummy shape environment that always hits, we can
    make the Inductor code cache lookup always hit, and we can
    compile the graph for different shapes as needed.

    The following dummy methods are obtained by trial-and-error
    until it works.
    """

    def __init__(self) -> None:
        self.guards: list[Any] = []

    def evaluate_guards_expression(self, *args: Any, **kwargs: Any) -> Literal[True]:
        # Always claim the guards pass so the cache lookup hits.
        return True

    def get_pruned_guards(self, *args: Any, **kwargs: Any) -> list[Any]:
        return []

    def produce_guards_expression(self, *args: Any, **kwargs: Any) -> Literal[""]:
        return ""
|
||||
|
||||
|
||||
def get_inductor_factors() -> list[Any]:
|
||||
factors: list[Any] = []
|
||||
# summarize system state
|
||||
from torch._inductor.codecache import CacheBase
|
||||
|
||||
system_factors = CacheBase.get_system()
|
||||
factors.append(system_factors)
|
||||
|
||||
# summarize pytorch state
|
||||
from torch._inductor.codecache import torch_key
|
||||
|
||||
torch_factors = torch_key()
|
||||
factors.append(torch_factors)
|
||||
return factors
|
||||
|
||||
|
||||
def is_compile_cache_enabled(
    vllm_additional_inductor_config: dict[str, Any],
) -> bool:
    """True unless caching is disabled by vLLM env, torch inductor config,
    or the user-supplied inductor config overrides."""
    vllm_inductor_config_disable_cache = vllm_additional_inductor_config.get(
        "force_disable_caches", False
    )

    # TODO(gmagogsfm): Replace torch._inductor.config.force_disable_caches
    # with torch.compiler.config.force_disable_caches when minimum PyTorch
    # version reaches 2.10
    return (
        not envs.VLLM_DISABLE_COMPILE_CACHE
        and not torch._inductor.config.force_disable_caches
        and not vllm_inductor_config_disable_cache
    )
|
||||
|
||||
|
||||
class InductorStandaloneAdaptor(CompilerInterface):
|
||||
"""
|
||||
The adaptor for the Inductor compiler.
|
||||
Requires PyTorch 2.8+.
|
||||
This is not on by default yet, but we plan to turn it on by default for
|
||||
PyTorch 2.8.
|
||||
|
||||
Use VLLM_USE_STANDALONE_COMPILE to toggle this on or off.
|
||||
"""
|
||||
|
||||
name = "inductor_standalone"
|
||||
|
||||
def __init__(self, save_format: Literal["binary", "unpacked"]) -> None:
|
||||
self.save_format = save_format
|
||||
|
||||
def compute_hash(self, vllm_config: VllmConfig) -> str:
|
||||
factors = get_inductor_factors()
|
||||
hash_str: str = safe_hash(
|
||||
str(factors).encode(), usedforsecurity=False
|
||||
).hexdigest()[:10]
|
||||
return hash_str
|
||||
|
||||
def initialize_cache(
|
||||
self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
|
||||
) -> None:
|
||||
self.cache_dir = cache_dir
|
||||
|
||||
def compile(
|
||||
self,
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: list[Any],
|
||||
compiler_config: dict[str, Any],
|
||||
compile_range: Range,
|
||||
key: str | None = None,
|
||||
) -> tuple[Callable[..., Any] | None, Any | None]:
|
||||
compilation_counter.num_inductor_compiles += 1
|
||||
current_config = {}
|
||||
if compiler_config is not None:
|
||||
current_config.update(compiler_config)
|
||||
set_inductor_config(current_config, compile_range)
|
||||
set_functorch_config()
|
||||
|
||||
if compile_range.is_single_size():
|
||||
dynamic_shapes = "from_example_inputs"
|
||||
else:
|
||||
dynamic_shapes = "from_graph"
|
||||
|
||||
from torch._inductor import standalone_compile
|
||||
|
||||
supports_aot = is_torch_equal_or_newer("2.10.0")
|
||||
|
||||
if not supports_aot and envs.VLLM_USE_MEGA_AOT_ARTIFACT:
|
||||
logger.error(
|
||||
"CRITICAL: VLLM_USE_MEGA_AOT_ARTIFACT "
|
||||
"is enabled but PyTorch version does not support 'aot' "
|
||||
"parameter in standalone_compile. This requires PyTorch "
|
||||
"2.10.0+. Falling back to non-AOT mode."
|
||||
)
|
||||
|
||||
compile_kwargs = {
|
||||
"dynamic_shapes": dynamic_shapes,
|
||||
"options": {
|
||||
"config_patches": current_config,
|
||||
},
|
||||
}
|
||||
|
||||
use_aot: bool = supports_aot and envs.VLLM_USE_MEGA_AOT_ARTIFACT
|
||||
# only add 'aot' parameter if both supported and enabled...
|
||||
# this will set bundled_autograd_cache
|
||||
# https://github.com/pytorch/pytorch/blob/9bbc5b2905c260adf41bc866a732f9c121a2828a/torch/_inductor/standalone_compile.py#L359 # noqa
|
||||
if use_aot:
|
||||
compile_kwargs["aot"] = True # type: ignore[assignment]
|
||||
|
||||
# Inductor's pre-grad passes don't do anything for vLLM.
|
||||
# The pre-grad passes get run even on cache-hit and negatively impact
|
||||
# vllm cold compile times by O(1s)
|
||||
# Can remove this after the following issue gets fixed
|
||||
# https://github.com/pytorch/pytorch/issues/174502
|
||||
if envs.VLLM_ENABLE_PREGRAD_PASSES:
|
||||
ctx: Any = contextlib.nullcontext()
|
||||
else:
|
||||
ctx = patch(
|
||||
"torch._inductor.compile_fx._recursive_pre_grad_passes",
|
||||
lambda gm, _: gm,
|
||||
)
|
||||
with ctx:
|
||||
compiled_graph = standalone_compile(graph, example_inputs, **compile_kwargs)
|
||||
|
||||
if use_aot:
|
||||
from torch._inductor.standalone_compile import AOTCompiledArtifact
|
||||
|
||||
assert isinstance(compiled_graph, AOTCompiledArtifact)
|
||||
assert hasattr(compiled_graph, "serialize")
|
||||
# just return the compiled graph and a key
|
||||
# since we can serialize the bytes using to_bytes
|
||||
# and reload it using the key when reading
|
||||
return compiled_graph, None
|
||||
|
||||
# Save the compiled artifact to disk in the specified path
|
||||
assert key is not None
|
||||
path = os.path.join(self.cache_dir, key)
|
||||
|
||||
def is_saveable_2_10(compiled_artifact):
|
||||
# can just use compiled_artifact.is_saveable in 2.11
|
||||
if compiled_artifact._artifacts is None:
|
||||
return False
|
||||
_, cache_info = compiled_artifact._artifacts
|
||||
return len(cache_info.aot_autograd_artifacts) == 1
|
||||
|
||||
if is_compile_cache_enabled(compiler_config):
|
||||
if not is_saveable_2_10(compiled_graph):
|
||||
raise RuntimeError(
|
||||
"The compiled artifact is not serializable. This usually means "
|
||||
"that the model code has something that is not serializable "
|
||||
"by torch.compile in it. You can fix this by either "
|
||||
"figuring out what is not serializable and rewriting it, "
|
||||
"filing a bug report, "
|
||||
"or suppressing this error by "
|
||||
"disabling vLLM's compilation cache via "
|
||||
"VLLM_DISABLE_COMPILE_CACHE=1 "
|
||||
"(this will greatly increase vLLM server warm start times)."
|
||||
)
|
||||
compiled_graph.save(path=path, format=self.save_format)
|
||||
compilation_counter.num_compiled_artifacts_saved += 1
|
||||
return compiled_graph, (key, path)
|
||||
|
||||
def load(
    self,
    handle: Any,
    graph: fx.GraphModule,
    example_inputs: list[Any],
    graph_index: int,
    compile_range: Range,
) -> Callable[..., Any]:
    """Reload a previously saved standalone-compile artifact.

    `handle` is the (key, path) pair produced by `compile`; the artifact is
    deserialized from the path and wrapped so it follows Dynamo's calling
    convention (a bare value when the original graph does not return a
    tuple).
    """
    assert isinstance(handle, tuple)
    assert isinstance(handle[0], str)
    assert isinstance(handle[1], str)
    artifact_path = handle[1]
    loaded_artifact = torch._inductor.CompiledArtifact.load(
        path=artifact_path, format=self.save_format
    )
    from torch._inductor.compile_fx import graph_returns_tuple

    produces_tuple = graph_returns_tuple(graph)

    def runner(*args: Any) -> tuple[Any, ...] | Any:
        result = loaded_artifact(*args)
        # Dynamo expects a bare value when the graph does not return a
        # tuple, so unwrap the single-element tuple here.
        # TODO(rzou): the implication is that we're not
        # reading the python bytecode correctly in vLLM?
        return result if produces_tuple else result[0]

    return runner
|
||||
|
||||
|
||||
class InductorAdaptor(CompilerInterface):
    """
    The adaptor for the Inductor compiler, version 2.5, 2.6, 2.7.

    It relies on monkey-patching private Inductor entry points to recover
    the FxGraphCache hash and on-disk artifact path of each compiled graph,
    so the artifact can be reloaded from vLLM's cache directory later.
    """

    name = "inductor"

    def compute_hash(self, vllm_config: VllmConfig) -> str:
        """Return a short, stable hash of the Inductor-relevant factors."""
        factors = get_inductor_factors()
        hash_str: str = safe_hash(
            str(factors).encode(), usedforsecurity=False
        ).hexdigest()[:10]
        return hash_str

    def initialize_cache(
        self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
    ) -> None:
        """Redirect Inductor's and Triton's caches into vLLM's cache dir.

        This makes the cache directory self-contained: copying it to
        another machine is enough to reuse the compiled artifacts.
        """
        self.cache_dir = cache_dir
        self.prefix = prefix
        self.base_cache_dir = cache_dir[: -len(prefix)] if prefix else cache_dir
        if disable_cache:
            return
        # redirect the cache directory to a subdirectory
        # set flags so that Inductor and Triton store their cache
        # in the cache_dir, then users only need to copy the cache_dir
        # to another machine to reuse the cache.
        inductor_cache = os.path.join(self.base_cache_dir, "inductor_cache")
        os.makedirs(inductor_cache, exist_ok=True)
        os.environ["TORCHINDUCTOR_CACHE_DIR"] = inductor_cache
        triton_cache = os.path.join(self.base_cache_dir, "triton_cache")
        os.makedirs(triton_cache, exist_ok=True)
        os.environ["TRITON_CACHE_DIR"] = triton_cache

    def compile(
        self,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        compiler_config: dict[str, Any],
        compile_range: Range,
        key: str | None = None,
    ) -> tuple[Callable[..., Any] | None, Any | None]:
        """Compile `graph` via compile_fx and capture its cache handle.

        Returns the compiled callable plus a (hash_str, file_path) handle
        that `load` can later use to fetch the artifact from FxGraphCache.
        """
        compilation_counter.num_inductor_compiles += 1
        from torch._inductor.compile_fx import compile_fx

        current_config = {}
        if compiler_config is not None:
            current_config.update(compiler_config)

        # disable remote cache
        current_config["fx_graph_cache"] = True
        current_config["fx_graph_remote_cache"] = False

        set_inductor_config(current_config, compile_range)
        set_functorch_config()

        # inductor can inplace modify the graph, so we need to copy it
        # see https://github.com/pytorch/pytorch/issues/138980
        graph = copy.deepcopy(graph)

        # it's the first time we compile this graph
        # the assumption is that we don't have nested Inductor compilation.
        # compiled_fx_graph_hash will only be called once, and we can hook
        # it to get the hash of the compiled graph directly.

        hash_str, file_path = None, None
        from torch._inductor.codecache import compiled_fx_graph_hash

        def hijacked_compile_fx_inner(*args: Any, **kwargs: Any) -> Any:
            # Wraps compile_fx_inner to recover the on-disk file path of
            # the generated code (needed to reload the artifact later).
            output = torch._inductor.compile_fx.compile_fx_inner(*args, **kwargs)
            nonlocal hash_str
            inductor_compiled_graph = output
            if inductor_compiled_graph is not None:
                nonlocal file_path
                compiled_fn = inductor_compiled_graph.current_callable
                file_path = compiled_fn.__code__.co_filename  # noqa
                if (
                    not file_path.startswith(self.base_cache_dir)
                    and compiled_fn.__closure__ is not None
                ):
                    # hooked in the align_inputs_from_check_idxs function
                    # in torch/_inductor/utils.py
                    for cell in compiled_fn.__closure__:
                        if not callable(cell.cell_contents):
                            continue
                        code = cell.cell_contents.__code__
                        if code.co_filename.startswith(self.base_cache_dir):
                            # this is the real file path
                            # compiled from Inductor
                            file_path = code.co_filename
                            break
                hash_str = inductor_compiled_graph._fx_graph_cache_key
            return output

        def hijack_compiled_fx_graph_hash(*args: Any, **kwargs: Any) -> Any:
            # Record the FxGraphCache key while delegating to the original.
            out = compiled_fx_graph_hash(*args, **kwargs)
            nonlocal hash_str
            hash_str = out[0]
            return out

        def _check_can_cache(*args: Any, **kwargs: Any) -> None:
            # no error means it can be cached.
            # Inductor refuses to cache the graph outside of Dynamo
            # tracing context, and also disables caching for graphs
            # with high-order ops.
            # For vLLM, in either case, we want to cache the graph.
            # see https://github.com/pytorch/pytorch/blob/9f5ebf3fc609105a74eab4ccc24932d6353ff566/torch/_inductor/codecache.py#L1221 # noqa
            return

        def _get_shape_env() -> AlwaysHitShapeEnv:
            return AlwaysHitShapeEnv()

        with ExitStack() as stack:
            # for hijacking the hash of the compiled graph
            stack.enter_context(
                patch(
                    "torch._inductor.codecache.compiled_fx_graph_hash",
                    hijack_compiled_fx_graph_hash,
                )
            )

            # for providing a dummy shape environment
            stack.enter_context(
                patch(
                    "torch._inductor.codecache.FxGraphCache._get_shape_env",
                    _get_shape_env,
                )
            )

            from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache

            # torch 2.8+ on main uses _get_shape_env in AOTAutogradCache
            if hasattr(AOTAutogradCache, "_get_shape_env"):
                stack.enter_context(
                    patch(
                        "torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env",
                        _get_shape_env,
                    )
                )

            # for forcing the graph to be cached
            stack.enter_context(
                patch(
                    "torch._inductor.codecache.FxGraphCache._check_can_cache",
                    _check_can_cache,
                )
            )

            # Dynamo metrics context, see method for more details.
            stack.enter_context(self.metrics_context())

            # Disable remote caching. When these are on, on remote cache-hit,
            # the monkey-patched functions never actually get called.
            # vLLM today assumes and requires the monkey-patched functions to
            # get hit.
            # TODO(zou3519): we're going to replace this all with
            # standalone_compile sometime.
            stack.enter_context(
                torch._inductor.config.patch(fx_graph_remote_cache=False)
            )
            # InductorAdaptor (unfortunately) requires AOTAutogradCache
            # to be turned off to run. It will fail to acquire the hash_str
            # and error if not.
            # StandaloneInductorAdaptor (PyTorch 2.8+) fixes this problem.
            stack.enter_context(
                torch._functorch.config.patch(enable_autograd_cache=False)
            )
            stack.enter_context(
                torch._functorch.config.patch(enable_remote_autograd_cache=False)
            )

            compiled_graph = compile_fx(
                graph,
                example_inputs,
                inner_compile=hijacked_compile_fx_inner,
                config_patches=current_config,
            )

        # Turn off the checks if we disable the compilation cache.
        if is_compile_cache_enabled(compiler_config):
            if hash_str is None:
                raise RuntimeError(
                    "vLLM failed to compile the model. The most "
                    "likely reason for this is that a previous compilation "
                    "failed, leading to a corrupted compilation artifact. "
                    "We recommend trying to "
                    "remove ~/.cache/vllm/torch_compile_cache and try again "
                    "to see the real issue. "
                )
            assert file_path is not None, (
                "failed to get the file path of the compiled graph"
            )
        return compiled_graph, (hash_str, file_path)

    def load(
        self,
        handle: Any,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        graph_index: int,
        compile_range: Range,
    ) -> Callable[..., Any]:
        """Reload a compiled graph from FxGraphCache by its hash handle."""
        assert isinstance(handle, tuple)
        assert isinstance(handle[0], str)
        assert isinstance(handle[1], str)
        hash_str = handle[0]

        from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache
        from torch._inductor.codecache import FxGraphCache

        with ExitStack() as exit_stack:
            exit_stack.enter_context(
                patch(
                    "torch._inductor.codecache.FxGraphCache._get_shape_env",
                    lambda *args, **kwargs: AlwaysHitShapeEnv(),
                )
            )
            # torch 2.8+ on main uses _get_shape_env in AOTAutogradCache
            if hasattr(AOTAutogradCache, "_get_shape_env"):
                exit_stack.enter_context(
                    patch(
                        "torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env",
                        lambda *args, **kwargs: AlwaysHitShapeEnv(),
                    )
                )

            # Dynamo metrics context, see method for more details.
            exit_stack.enter_context(self.metrics_context())

            from torch._inductor.output_code import CompiledFxGraphConstantsWithGm

            constants = CompiledFxGraphConstantsWithGm(graph)
            inductor_compiled_graph, _ = FxGraphCache._lookup_graph(
                hash_str, example_inputs, True, None, constants
            )
            # Fixed: was an f-string with no placeholders (noqa-suppressed).
            assert inductor_compiled_graph is not None, (
                "Inductor cache lookup failed. Please remove "
                "the cache directory and try again."
            )

        # Inductor calling convention (function signature):
        # f(list) -> tuple
        # Dynamo calling convention (function signature):
        # f(*args) -> Any

        # need to know if the graph returns a tuple
        from torch._inductor.compile_fx import graph_returns_tuple

        returns_tuple = graph_returns_tuple(graph)

        # this is the callable we return to Dynamo to run
        def compiled_graph(*args: Any) -> tuple[Any, ...] | Any:
            # convert args to list
            list_args = list(args)
            graph_output = inductor_compiled_graph(list_args)
            # unpack the tuple if needed
            if returns_tuple:
                return graph_output
            else:
                return graph_output[0]

        return compiled_graph

    def metrics_context(self) -> contextlib.AbstractContextManager[Any]:
        """
        This method returns the Dynamo metrics context (if it exists,
        otherwise a null context). It is used by various compile components.
        Present in torch>=2.6, it's used inside FxGraphCache in
        torch==2.6 (but not after). It might also be used in various other
        torch.compile internal functions.

        Because it is re-entrant, we always set it (even if entering via Dynamo
        and the context was already entered). We might want to revisit if it
        should be set at a different mode of compilation.

        This is likely a bug in PyTorch: public APIs should not rely on
        manually setting up internal contexts. But we also rely on non-public
        APIs which might not provide these guarantees.
        """
        if is_torch_equal_or_newer("2.6"):
            import torch._dynamo.utils

            return torch._dynamo.utils.get_metrics_context()  # type: ignore[no-any-return]
        else:
            return contextlib.nullcontext()
|
||||
|
||||
|
||||
def set_inductor_config(config: dict[str, Any], compile_range: Range) -> None:
    """Enable Inductor autotuning knobs for single-size compile ranges.

    For a range covering more than one batch size the config is left
    untouched; the knobs themselves come from vLLM environment variables.
    """
    if not compile_range.is_single_size():
        return
    # A fixed batch size makes tuning triton kernel parameters worthwhile.
    config.update(
        max_autotune=envs.VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE,
        coordinate_descent_tuning=envs.VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING,
    )
|
||||
|
||||
|
||||
def set_functorch_config() -> None:
    """Turn off functorch's bundled autograd cache when the
    VLLM_USE_MEGA_AOT_ARTIFACT environment flag is disabled."""
    mega_aot_enabled = envs.VLLM_USE_MEGA_AOT_ARTIFACT
    if not mega_aot_enabled:
        torch._functorch.config.bundled_autograd_cache = False
|
||||
|
||||
|
||||
class EagerAdaptor(CompilerInterface):
    """A no-op "compiler" that returns the fx graph itself to run eagerly."""

    name = "eager"

    def compile(
        self,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        compiler_config: dict[str, Any],
        compile_range: Range,
        key: str | None = None,
    ) -> tuple[Callable[..., Any] | None, Any | None]:
        """Return `graph` unchanged; eager mode performs no compilation."""
        compilation_counter.num_eager_compiles += 1
        # we don't need to compile the graph, just return the graph itself.
        # It does not support caching, return None for the handle.
        return graph, None
|
||||
50
vllm/compilation/counter.py
Normal file
50
vllm/compilation/counter.py
Normal file
@@ -0,0 +1,50 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import copy
|
||||
import dataclasses
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclasses.dataclass
class CompilationCounter:
    """Process-wide tally of compilation/cudagraph events.

    Counters are bumped by the compilation machinery; `expect` lets tests
    assert exact deltas around a code region.
    """

    num_models_seen: int = 0
    num_graphs_seen: int = 0
    # including the splitting ops
    num_piecewise_graphs_seen: int = 0
    # not including the splitting ops
    num_piecewise_capturable_graphs_seen: int = 0
    num_backend_compilations: int = 0
    # Number of gpu_model_runner attempts to trigger CUDAGraphs capture
    num_gpu_runner_capture_triggers: int = 0
    # Number of CUDAGraphs captured
    num_cudagraph_captured: int = 0
    # InductorAdapter.compile calls
    num_inductor_compiles: int = 0
    # EagerAdapter.compile calls
    num_eager_compiles: int = 0
    # The number of time vLLM's compiler cache entry was updated
    num_cache_entries_updated: int = 0
    # The number of standalone_compile compiled artifacts saved
    num_compiled_artifacts_saved: int = 0
    # Number of times a model was loaded with CompilationMode.STOCK_TORCH_COMPILE
    stock_torch_compile_count: int = 0

    def clone(self) -> "CompilationCounter":
        """Return an independent snapshot of the current counts."""
        return copy.deepcopy(self)

    @contextmanager
    def expect(self, **kwargs: Any) -> Generator[None, None, None]:
        """Assert each named counter changes by exactly the given delta
        across the `with` body."""
        snapshot = self.clone()
        yield
        for name, delta in kwargs.items():
            before = getattr(snapshot, name)
            after = getattr(self, name)
            assert after - before == delta, (
                f"{name} not as expected, before it is {before}"
                f", after it is {after}, "
                f"expected diff is {delta}"
            )


# Singleton instance shared by the whole process.
compilation_counter = CompilationCounter()
|
||||
332
vllm/compilation/cuda_graph.py
Normal file
332
vllm/compilation/cuda_graph.py
Normal file
@@ -0,0 +1,332 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import dataclasses
|
||||
from collections import Counter
|
||||
from collections.abc import Callable
|
||||
from contextlib import ExitStack
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.monitor import validate_cudagraph_capturing_enabled
|
||||
from vllm.config import CUDAGraphMode, VllmConfig
|
||||
from vllm.distributed.device_communicators.pynccl_allocator import set_graph_pool_id
|
||||
from vllm.forward_context import BatchDescriptor, get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.offloader.base import get_offloader
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import current_stream, weak_ref_tensors
|
||||
from vllm.sequence import IntermediateTensors
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class CUDAGraphStat:
    """One observed cudagraph dispatch; frozen so it is hashable and can be
    counted with collections.Counter by CUDAGraphLogging."""

    # token count of the batch before padding
    num_unpadded_tokens: int
    # token count after padding to a capture size
    num_padded_tokens: int
    # number of padding tokens added (padded - unpadded, presumably — confirm at call site)
    num_paddings: int
    # string form of the cudagraph runtime mode used for this dispatch
    runtime_mode: str
|
||||
|
||||
|
||||
class CUDAGraphLogging:
    """Aggregate cudagraph dispatch stats and log them as a markdown table."""

    COLUMN_HEADERS = [
        "Unpadded Tokens",
        "Padded Tokens",
        "Num Paddings",
        "Runtime Mode",
        "Count",
    ]

    def __init__(
        self, cg_mode: CUDAGraphMode, cg_capture_sizes: list[int] | None
    ) -> None:
        self.reset()
        self.cg_mode = str(cg_mode)
        self.cg_capture_sizes = str(cg_capture_sizes or [])

        self.settings_header = (
            "**CUDAGraph Config Settings:**\n\n"
            f"- Mode: {self.cg_mode}\n"
            f"- Capture sizes: {self.cg_capture_sizes}\n\n"
            "**CUDAGraph Stats:**\n\n"
        )

    def reset(self) -> None:
        """Discard every stat collected so far."""
        self.stats: list[CUDAGraphStat] = []

    def observe(self, cudagraph_stat: CUDAGraphStat) -> None:
        """Record a single cudagraph dispatch observation."""
        self.stats.append(cudagraph_stat)

    def generate_metric_table(self) -> str:
        """Render the aggregated stats as a column-aligned markdown table,
        ordered by descending observation frequency."""
        frequency = Counter(self.stats)
        ordered = sorted(frequency.items(), key=lambda item: item[1], reverse=True)
        rows = [
            [
                str(stat.num_unpadded_tokens),
                str(stat.num_padded_tokens),
                str(stat.num_paddings),
                stat.runtime_mode,
                str(count),
            ]
            for stat, count in ordered
        ]

        # Each column is as wide as its widest cell, header included.
        widths = [
            max([len(header)] + [len(row[i]) for row in rows])
            for i, header in enumerate(self.COLUMN_HEADERS)
        ]

        header_line = (
            "| "
            + " | ".join(h.ljust(w) for h, w in zip(self.COLUMN_HEADERS, widths))
            + " |\n"
        )
        separator_line = "|" + "|".join("-" * (w + 2) for w in widths) + "|\n"
        body_lines = [
            "| "
            + " | ".join(str(cell).ljust(w) for cell, w in zip(row, widths))
            + " |"
            for row in rows
        ]

        return (
            self.settings_header
            + header_line
            + separator_line
            + "\n".join(body_lines)
            + "\n"
        )

    def log(self, log_fn: Callable[..., Any] = logger.info) -> None:
        """Emit the table through `log_fn` and reset; no-op when empty."""
        if not self.stats:
            return
        log_fn(self.generate_metric_table())
        self.reset()
|
||||
|
||||
|
||||
@dataclasses.dataclass
class CUDAGraphEntry:
    """Per-batch-descriptor capture state held by CUDAGraphWrapper."""

    # the dispatch key this entry was created for
    batch_descriptor: BatchDescriptor
    # captured graph; None until the first (capturing) invocation
    cudagraph: torch.cuda.CUDAGraph | None = None
    # output object returned on every replay
    output: Any | None = None

    # for cudagraph debugging, track the input addresses
    # during capture, and check if they are the same during replay
    input_addresses: list[int] | None = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
class CUDAGraphOptions:
    """Knobs controlling how CUDAGraphWrapper captures a graph."""

    # log each capture at debug level (can be disabled for noisy piecewise mode)
    debug_log_enable: bool = True
    # patch out gc.collect/torch.cuda.empty_cache during capture
    gc_disable: bool = False
    # convert the captured output to weak refs to release memory early
    weak_ref_output: bool = True
|
||||
|
||||
|
||||
class CUDAGraphWrapper:
    """Wraps a runnable to add CUDA graph capturing and replaying ability. And
    provide attribute access to the underlying `runnable` via `__getattr__`.

    The workflow of this wrapper in the cudagraph dispatching is as follows:
    1. At initialization, a runtime mode is assigned to the wrapper (FULL or
    PIECEWISE).
    2. At runtime, the wrapper receives a runtime_mode and a
    batch_descriptor(key) from the forward context and blindly trust them
    for cudagraph dispatching.
    3. If runtime_mode is NONE or runtime_mode does not match the mode of the
    wrapper, just call the runnable directly.
    4. Otherwise, i.e., the runtime_mode matches the mode of the wrapper,
    the wrapper will perform cudagraph capture(if key does not exist, create
    a new entry and cache it) or replay (if key exists in the cache).

    Note: CUDAGraphWrapper does not store persistent buffers or copy any
    runtime inputs into that buffers for replay. We assume implementing them
    is done outside of the wrapper. That is because we do not make any
    assumption on the dynamic shape (batch size) of the runtime inputs, as a
    trade-off for staying orthogonal to compilation logic. Nevertheless,
    tracing and checking the input addresses to be consistent during replay is
    guaranteed when VLLM_LOGGING_LEVEL == "DEBUG".
    """

    def __init__(
        self,
        runnable: Callable[..., Any],
        vllm_config: VllmConfig,
        runtime_mode: CUDAGraphMode,
        cudagraph_options: CUDAGraphOptions | None = None,
    ) -> None:
        """Bind `runnable` to a concrete cudagraph runtime mode.

        `runtime_mode` must not be NONE; options default to CUDAGraphOptions().
        """
        self.runnable = runnable
        self.vllm_config = vllm_config
        self.runtime_mode = runtime_mode
        self.compilation_config = vllm_config.compilation_config

        # NOTE(review): set but never read inside this class — presumably
        # consulted by callers; confirm before removing.
        self.first_run_finished = False
        self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"

        # assert runtime_mode is not NONE(no cudagraph), otherwise, we don't
        # need to initialize a CUDAGraphWrapper.
        assert self.runtime_mode != CUDAGraphMode.NONE
        # TODO: in the future, if we want to use multiple
        # streams, it might not be safe to share a global pool.
        # only investigate this when we use multiple streams
        self.graph_pool = current_platform.get_global_graph_pool()

        if cudagraph_options is None:
            cudagraph_options = CUDAGraphOptions()
        self.cudagraph_options = cudagraph_options
        # the entries for different batch descriptors that we need to capture
        # cudagraphs for.
        self.concrete_cudagraph_entries: dict[BatchDescriptor, CUDAGraphEntry] = {}

    def __getattr__(self, key: str) -> Any:
        # allow accessing the attributes of the runnable.
        if hasattr(self.runnable, key):
            return getattr(self.runnable, key)
        raise AttributeError(
            f"Attribute {key} not exists in the runnable of "
            f"cudagraph wrapper: {self.runnable}"
        )

    def unwrap(self) -> Callable[..., Any]:
        # in case we need to access the original runnable.
        return self.runnable

    def weak_ref_tensors_with_intermediate(self, output: Any) -> Any:
        """Weak-ref the output tensors, handling IntermediateTensors
        containers by weak-ref'ing each contained tensor individually."""
        if isinstance(output, IntermediateTensors):
            intermediate_states = IntermediateTensors(
                tensors={key: weak_ref_tensors(value) for key, value in output.tensors.items()})
            return intermediate_states
        return weak_ref_tensors(output)

    def __call__(self, *args: Any, **kwargs: Any) -> Any | None:
        """Dispatch to direct call, cudagraph capture, or replay, based on
        the runtime mode and batch descriptor in the forward context."""
        forward_context = get_forward_context()
        batch_descriptor = forward_context.batch_descriptor
        cudagraph_runtime_mode = forward_context.cudagraph_runtime_mode

        if (
            cudagraph_runtime_mode == CUDAGraphMode.NONE
            or cudagraph_runtime_mode != self.runtime_mode
        ):
            # CUDAGraphMode.NONE could mean the profile run, a warmup run, or
            # running without cudagraphs.
            # We do not trigger capture/replay if the runtime mode is not
            # matches. This enables properly dispatching to the correct
            # CUDAGraphWrapper when nesting multiple instances with different
            # runtime modes.
            return self.runnable(*args, **kwargs)

        assert batch_descriptor is not None
        if batch_descriptor not in self.concrete_cudagraph_entries:
            # create a new entry for this batch descriptor
            self.concrete_cudagraph_entries[batch_descriptor] = CUDAGraphEntry(
                batch_descriptor=batch_descriptor
            )

        entry = self.concrete_cudagraph_entries[batch_descriptor]

        # cudagraph is None exactly once per entry: the capture path.
        if entry.cudagraph is None:
            if self.cudagraph_options.debug_log_enable:
                # Since we capture cudagraph for many different shapes and
                # capturing is fast, we don't need to log it for every
                # shape. E.g. we only log it for the first subgraph in
                # piecewise mode.
                logger.debug(
                    "Capturing a cudagraph on (%s,%s)",
                    self.runtime_mode.name,
                    entry.batch_descriptor,
                )
            # validate that cudagraph capturing is legal at this point.
            validate_cudagraph_capturing_enabled()

            input_addresses = [
                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
            ]
            entry.input_addresses = input_addresses
            cudagraph = torch.cuda.CUDAGraph()

            with ExitStack() as stack:
                if self.cudagraph_options.gc_disable:
                    # during every model forward for piecewise cudagraph
                    # mode, we will capture many pieces of cudagraphs
                    # (roughly one per layer). running gc again and again
                    # across layers will make the cudagraph capture very slow.
                    # therefore, we only run gc for the first graph,
                    # and disable gc for the rest of the graphs.
                    stack.enter_context(patch("gc.collect", lambda: None))
                    stack.enter_context(patch("torch.cuda.empty_cache", lambda: None))

                if self.graph_pool is not None:
                    set_graph_pool_id(self.graph_pool)
                else:
                    set_graph_pool_id(current_platform.graph_pool_handle())

                # Sync offloader's copy stream before capture.
                # Ensure any pre-capture prefetches from offloader are complete.
                get_offloader().sync_prev_onload()

                # mind-exploding: carefully manage the reference and memory.
                with torch.cuda.graph(
                    cudagraph,
                    pool=self.graph_pool,
                    stream=current_stream(),
                ):
                    # `output` is managed by pytorch's cudagraph pool
                    output = self.runnable(*args, **kwargs)
                    # Join offloader's copy stream after forward to avoid
                    # unjoined stream error. The last layer's start_prefetch
                    # forks copy_stream, but wait_prefetch only happens in
                    # the next forward pass.
                    get_offloader().join_after_forward()
                    if self.cudagraph_options.weak_ref_output:
                        # by converting it to weak ref,
                        # the original `output` will immediately be released
                        # to save memory. It is only safe to do this for
                        # the last graph in piecewise cuadgraph mode, because
                        # the output of the last graph will not be used by
                        # any other cuda graph.
                        # output = weak_ref_tensors(output)
                        output = self.weak_ref_tensors_with_intermediate(output)

            # here we always use weak ref for the output
            # to save memory
            # entry.output = weak_ref_tensors(output)
            entry.output = self.weak_ref_tensors_with_intermediate(output)
            entry.cudagraph = cudagraph

            compilation_counter.num_cudagraph_captured += 1

            # important: we need to return the output, rather than
            # the weak ref of the output, so that pytorch can correctly
            # manage the memory during cuda graph capture
            return output

        # Replay path: the graph was already captured for this descriptor.
        if self.is_debugging_mode:
            # check if the input addresses are the same
            new_input_addresses = [
                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
            ]
            assert new_input_addresses == entry.input_addresses, (
                f"Input addresses for cudagraphs are different "
                f"during replay. Expected {entry.input_addresses}, "
                f"got {new_input_addresses}"
            )

        # Sync offloader before replay - ensures any external dependencies
        # from pre-capture prefetches are satisfied.
        get_offloader().sync_prev_onload()
        entry.cudagraph.replay()
        return entry.output
|
||||
657
vllm/compilation/decorators.py
Normal file
657
vllm/compilation/decorators.py
Normal file
@@ -0,0 +1,657 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import contextlib
|
||||
import hashlib
|
||||
import inspect
|
||||
import os
|
||||
import sys
|
||||
from collections.abc import Callable, Generator
|
||||
from typing import TYPE_CHECKING, Any, TypeVar, overload
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch._dynamo.symbolic_convert import InliningInstructionTranslator
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.wrapper import TorchCompileWithNoGuardsWrapper
|
||||
from vllm.config import (
|
||||
CompilationMode,
|
||||
VllmConfig,
|
||||
get_current_vllm_config,
|
||||
set_current_vllm_config,
|
||||
)
|
||||
from vllm.config.compilation import DynamicShapesType
|
||||
from vllm.forward_context import get_forward_context, is_forward_context_available
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils.import_utils import resolve_obj_by_qualname
|
||||
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||
|
||||
from .monitor import start_monitoring_torch_compile
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# Only added on nightly/2.10 so wrap
|
||||
try:
|
||||
from torch._dynamo.package import SourceInfo
|
||||
except ImportError:
|
||||
# Fallback for old versions not supporting
|
||||
SourceInfo = Any
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
IGNORE_COMPILE_KEY = "_ignore_compile_vllm"
|
||||
|
||||
_T = TypeVar("_T", bound=nn.Module)
|
||||
|
||||
|
||||
def ignore_torch_compile(cls: type[_T]) -> type[_T]:
    """Mark `cls` so that an inherited support_torch_compile is skipped.

    Useful when a parent class carries the support_torch_compile decorator
    but this subclass's forward should not be compiled. The marker applies
    only to the decorated class itself:

    - a child that re-applies support_torch_compile is still compiled;
    - submodules decorated with support_torch_compile are still compiled.
    """
    setattr(cls, IGNORE_COMPILE_KEY, True)
    return cls
|
||||
|
||||
|
||||
def _should_ignore_torch_compile(cls: type[_T]) -> bool:
|
||||
"""
|
||||
Check if the class should be ignored for torch.compile.
|
||||
"""
|
||||
return getattr(cls, IGNORE_COMPILE_KEY, False)
|
||||
|
||||
|
||||
# Typing-only overloads for `support_torch_compile`: it can be used either
# bare (`@support_torch_compile`) or called with keyword arguments
# (`@support_torch_compile(...)`), in which case it returns a class decorator.
@overload
def support_torch_compile(
    *,
    enable_if: Callable[[VllmConfig], bool] | None = None,
) -> Callable[[type[_T]], type[_T]]: ...


@overload
def support_torch_compile(
    *,
    dynamic_arg_dims: dict[str, int | list[int]] | None,
) -> Callable[[type[_T]], type[_T]]: ...


@overload
def support_torch_compile(
    *,
    mark_unbacked_dims: dict[str, int | list[int]] | None,
) -> Callable[[type[_T]], type[_T]]: ...


@overload
def support_torch_compile(
    *,
    dynamic_arg_dims: dict[str, int | list[int]] | None,
    mark_unbacked_dims: dict[str, int | list[int]] | None,
) -> Callable[[type[_T]], type[_T]]: ...


@overload
def support_torch_compile(cls: type[_T]) -> type[_T]: ...
|
||||
|
||||
|
||||
def support_torch_compile(
    cls: type[_T] | None = None,
    *,
    dynamic_arg_dims: dict[str, int | list[int]] | None = None,
    mark_unbacked_dims: dict[str, int | list[int]] | None = None,
    enable_if: Callable[[VllmConfig], bool] | None = None,
    shape_invariants: Callable[..., None] = lambda *args, **kwargs: None,
) -> Callable[[type[_T]], type[_T]] | type[_T]:
    """
    A decorator to add support for compiling the forward method of a class.

    Usage 1: use directly as a decorator without arguments:

    ```python
    @support_torch_compile
    class MyModel(nn.Module):
        def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): ...
    ```

    Usage 2: use as a decorator with arguments:

    ```python
    @support_torch_compile(dynamic_arg_dims={"x": 0, "y": 0})
    class MyModel(nn.Module):
        def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): ...
    ```

    `dynamic_arg_dims` is a dictionary that maps argument names to the dynamic
    dimensions of the argument. The dynamic dimensions can be either a single
    integer or a list of integers.

    if `dynamic_arg_dims` is `None`, it is inferred from the type annotation
    of the `forward` method, based on the following default rules:

    - if the argument is annotated as `torch.Tensor` or
        `Optional[torch.Tensor]`, the first dimension will be
        marked as dynamic.
    - if the argument is annotated as `IntermediateTensors`, the first
        dimension of all the tensors in the intermediate tensors
        will be marked as dynamic.

    During runtime, when we actually mark dimensions of tensors,
    it depends on the value of arguments:

    - if it is a single integer (can be negative), the corresponding dimension
        of the argument will be marked as dynamic.
    - if it is `None`, ignored.
    - if it is `IntermediateTensors`, all the tensors in the intermediate
        tensors will be marked as dynamic.
    - otherwise, it will raise an error.

    NOTE: if an argument is `None`, it should always be passed as `None` during
    the lifetime of the model, otherwise, it cannot be captured as a single
    computation graph.

    `enable_if` is a function that takes a `VllmConfig` object as input and
    returns a boolean value indicating whether to compile the model or not.
    This is useful if you want to compile the model only when certain
    conditions are met.

    `mark_unbacked_dims` is a dictionary that maps argument names with a dynamic
    dim to be decorated with `mark_unbacked`. This is useful if we would like to
    enforce that dynamo does not specialize on 0/1 values in the case of dummy input
    such as for vision model compilation

    `shape_invariants` is a function that gets compiled right before forward.
    The function should have the torch._check calls that are needed to set
    the relationships between different input sizes. For example:
    torch._check(input_ids.size()[0] == inputs_embeds.size()[0])
    This enforces constraints on the symbolic shapes without hardcoding
    specific values. It is needed for some models to avoid data dependent
    errors.
    """

    def cls_decorator_helper(cls: type[_T]) -> type[_T]:
        # helper to pass `dynamic_arg_dims` to `_support_torch_compile`
        # to avoid too much indentation for `_support_torch_compile`
        if not hasattr(cls, "forward"):
            raise TypeError("decorated class should have a forward method.")
        sig = inspect.signature(cls.forward)
        inferred_dynamic_arg_dims = dynamic_arg_dims
        if inferred_dynamic_arg_dims is None:
            # Infer: any forward() argument annotated as (optional)
            # torch.Tensor or IntermediateTensors has dim 0 marked dynamic.
            inferred_dynamic_arg_dims = {}
            for k, v in sig.parameters.items():
                if v.annotation in [
                    torch.Tensor,
                    torch.Tensor | None,
                    IntermediateTensors,
                    IntermediateTensors | None,
                ]:
                    inferred_dynamic_arg_dims[k] = 0

            logger.debug(
                ("Inferred dynamic dimensions for forward method of %s: %s"),
                cls,
                list(inferred_dynamic_arg_dims.keys()),
            )

        if len(inferred_dynamic_arg_dims) == 0:
            raise ValueError(
                "No dynamic dimensions found in the forward method of "
                f"{cls}. Please provide dynamic_arg_dims explicitly."
            )

        # Every requested dynamic argument must actually exist on forward().
        for k in inferred_dynamic_arg_dims:
            if k not in sig.parameters:
                raise ValueError(
                    f"Argument {k} not found in the forward method of {cls}"
                )
        return _support_torch_compile(
            cls,
            inferred_dynamic_arg_dims,
            mark_unbacked_dims,
            enable_if,
            shape_invariants,
        )

    if cls is not None:
        # use `support_torch_compile` as a decorator without arguments
        assert isinstance(cls, type)
        return cls_decorator_helper(cls)

    # Used with arguments: return the actual class decorator.
    return cls_decorator_helper
|
||||
|
||||
|
||||
def _model_hash_key(fn: Callable[..., Any]) -> str:
    """Build a stable hash key identifying *fn* for the AOT compile cache.

    The key combines the vLLM version, the function's qualified name, and
    the first line number of its code object, so a new vLLM release or a
    moved/renamed function yields a different cache entry.
    """
    import vllm

    digest = hashlib.sha256()
    for part in (
        vllm.__version__,
        fn.__qualname__,
        str(fn.__code__.co_firstlineno),
    ):
        digest.update(part.encode())
    return digest.hexdigest()
|
||||
|
||||
|
||||
def _verify_source_unchanged(
    source_info: "SourceInfo", vllm_config: VllmConfig
) -> None:
    """Raise RuntimeError if any source file traced into a serialized AOT
    artifact differs from the file currently on disk.

    Compares a checksum computed from the file contents recorded in
    ``source_info`` against a checksum computed from the same files as they
    exist now. Also registers every such file in
    ``compilation_config.traced_files`` so later cache invalidation sees them.
    """
    from .caching import _compute_code_hash, _compute_code_hash_with_content

    file_contents = {}
    for source in source_info.inlined_sources:
        # Resolve the recorded module name to its current on-disk file.
        module = sys.modules[source.module]
        file = inspect.getfile(module)
        vllm_config.compilation_config.traced_files.add(file)
        file_contents[file] = source.content
    # Checksum of the contents captured at serialization time...
    expected_checksum = _compute_code_hash_with_content(file_contents)
    # ...versus a checksum of the same files read from disk right now.
    actual_checksum = _compute_code_hash(set(file_contents.keys()))
    if expected_checksum != actual_checksum:
        raise RuntimeError(
            "Source code has changed since the last compilation. Recompiling the model."
        )
|
||||
|
||||
|
||||
def _support_torch_compile(
    cls: type[_T],
    dynamic_arg_dims: dict[str, int | list[int]],
    mark_unbacked_dims: dict[str, int | list[int]] | None = None,
    enable_if: Callable[[VllmConfig], bool] | None = None,
    shape_invariants: Callable[..., None] = lambda *args, **kwargs: None,
) -> type[_T]:
    """
    A decorator to add support for compiling the forward method of a class.

    Rewrites *cls* in place: adds ``TorchCompileWithNoGuardsWrapper`` as a
    base class and replaces ``__init__`` and ``__call__`` with compiling
    versions. See ``support_torch_compile`` for the argument semantics.
    """
    if TorchCompileWithNoGuardsWrapper in cls.__bases__:
        # support decorating multiple times
        return cls

    # take care of method resolution order
    # make sure super().__init__ is called on the base class
    # other than TorchCompileWithNoGuardsWrapper
    cls.__bases__ = cls.__bases__ + (TorchCompileWithNoGuardsWrapper,)

    old_init = cls.__init__

    # Explicitly un-mark the class: a parent may carry ignore_torch_compile,
    # but this class has opted in to compilation.
    setattr(cls, IGNORE_COMPILE_KEY, False)

    def __init__(
        self: _T,
        *,
        vllm_config: VllmConfig | None = None,
        prefix: str = "",
        **kwargs: Any,
    ) -> None:
        if vllm_config is None:
            vllm_config = get_current_vllm_config()

        # NOTE: to support multimodal models (such as encoder),
        # we may not have vllm_config so we may need to patch
        # it
        sig = inspect.signature(old_init)
        # Forward vllm_config/prefix only if the wrapped __init__ accepts them.
        if "vllm_config" in sig.parameters:
            kwargs["vllm_config"] = vllm_config
        if "prefix" in sig.parameters:
            kwargs["prefix"] = prefix
        old_init(self, **kwargs)

        self.vllm_config = vllm_config
        self.compilation_config = self.vllm_config.compilation_config
        enable_compile = enable_if is None or enable_if(vllm_config)
        # for CompilationMode.STOCK_TORCH_COMPILE , the upper level model runner
        # will handle the compilation, so we don't need to do anything here.
        self.do_not_compile = (
            self.compilation_config.mode
            in [CompilationMode.NONE, CompilationMode.STOCK_TORCH_COMPILE]
            or _should_ignore_torch_compile(self.__class__)
            or not enable_compile
        )
        if self.do_not_compile:
            return

        self._check_shape_invariants = shape_invariants
        self.was_aot_compile_fn_loaded_from_disk = False
        compilation_counter.num_models_seen += 1
        self.compiled = False

        # Handled by monkeypatching `TorchCompileWithNoGuardsWrapper` into base class
        TorchCompileWithNoGuardsWrapper.__init__(self)

    cls.__init__ = __init__

    def _mark_dynamic_inputs(
        mod: _T, ds_type: DynamicShapesType, *args: Any, **kwargs: Any
    ) -> None:
        # Mark the configured argument dimensions as dynamic (or unbacked)
        # on the concrete tensors of the first call, before compilation.
        def mark_dynamic(arg: torch.Tensor, dims: list[int]) -> None:
            if ds_type == DynamicShapesType.UNBACKED:
                if is_torch_equal_or_newer("2.10.0"):
                    # hint_override is only available on torch >= 2.10.
                    for dim in dims:
                        torch._dynamo.decorators.mark_unbacked(
                            arg, dim, hint_override=arg.size()[dim]
                        )
                else:
                    torch._dynamo.decorators.mark_unbacked(arg, dims)
            else:
                torch._dynamo.mark_dynamic(arg, dims)

        sig = inspect.signature(mod.__class__.forward)  # type: ignore[attr-defined]
        bound_args = sig.bind(mod, *args, **kwargs)
        bound_args.apply_defaults()
        for k, dims in dynamic_arg_dims.items():
            arg = bound_args.arguments.get(k)

            # `None` arguments are ignored (documented contract: such args
            # must stay None for the model's lifetime).
            if arg is not None:
                dims = [dims] if isinstance(dims, int) else dims
                if isinstance(arg, torch.Tensor):
                    # In case dims is specified with negative indexing
                    dims = [arg.ndim + dim if dim < 0 else dim for dim in dims]
                    mark_dynamic(arg, dims)
                elif isinstance(arg, IntermediateTensors):
                    for tensor in arg.tensors.values():
                        # In case dims is specified with negative indexing
                        dims = [tensor.ndim + dim if dim < 0 else dim for dim in dims]
                        mark_dynamic(tensor, dims)
                else:
                    raise ValueError(
                        "Unsupported dynamic dimensions"
                        f" {dims} for argument {k} with type {type(arg)}."
                    )
        if mark_unbacked_dims:
            # Unconditionally mark these dims as unbacked so dynamo never
            # specializes on 0/1 sizes (e.g. dummy vision inputs).
            for k, dims in mark_unbacked_dims.items():
                arg = bound_args.arguments.get(k)
                if arg is not None:
                    dims = [dims] if isinstance(dims, int) else dims
                    if isinstance(arg, torch.Tensor):
                        # In case dims is specified with negative indexing
                        dims = [arg.ndim + dim if dim < 0 else dim for dim in dims]
                        if is_torch_equal_or_newer("2.10.0"):
                            for dim in dims:
                                torch._dynamo.decorators.mark_unbacked(
                                    arg, dim, hint_override=arg.size()[dim]
                                )
                        else:
                            torch._dynamo.decorators.mark_unbacked(arg, dims)

    def __call__(self: _T, *args: Any, **kwargs: Any) -> Any:
        # torch.compiler.is_compiling() means we are inside the compilation
        # e.g. TPU has the compilation logic in model runner, so we don't
        # need to compile the model inside.
        if self.do_not_compile or torch.compiler.is_compiling():
            return self.forward(*args, **kwargs)

        # If skip_compiled is set, bypass compiled model call. This is used e.g. for
        # enc-dec models where tensor shapes/types vary across invocations, preventing
        # the capture of a single computational graph.
        if is_forward_context_available() and get_forward_context().skip_compiled:
            return self.forward(*args, **kwargs)

        # if aot_compiled_fn is set, call it with partition wrapper context.
        # The partition wrapper must be active at runtime for CUDA graph
        # capture to work correctly with inductor graph partitioning.
        if getattr(self, "aot_compiled_fn", None) is not None:
            with maybe_use_cudagraph_partition_wrapper(self.vllm_config):
                return self.aot_compiled_fn(self, *args, **kwargs)

        ds_type = self.compilation_config.dynamic_shapes_config.type
        cache_dir = None
        aot_compilation_path = None
        if envs.VLLM_USE_AOT_COMPILE:
            """
            When using torch.compile in AOT mode, we store the cache artifacts
            under VLLM_CACHE_ROOT/torch_compile_cache/torch_aot_compile/{hash}
            The {hash} contains all of the factors except for the source files
            being traced through, because we don't actually know which source
            files to check at this point (before dynamo runs).
            On loading we will actually look at the source files being traced
            through. If any source file have changed (compared with the
            serialized backend artifacts), then we need to generate a new AOT
            compile artifact from scratch.
            """
            from .caching import aot_compile_hash_factors

            factors: list[str] = aot_compile_hash_factors(self.vllm_config)

            # Include the model identity in the hash key.
            factors.append(_model_hash_key(self.forward))
            hash_key = hashlib.sha256(str(factors).encode()).hexdigest()
            cache_dir = os.path.join(
                envs.VLLM_CACHE_ROOT,
                "torch_compile_cache",
                "torch_aot_compile",
                hash_key,
            )

            # One artifact per (TP rank, DP rank) pair.
            rank = self.vllm_config.parallel_config.rank
            dp_rank = self.vllm_config.parallel_config.data_parallel_index
            cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}")
            aot_compilation_path = os.path.join(cache_dir, "model")
            try:
                with (
                    set_current_vllm_config(self.vllm_config),
                    open(aot_compilation_path, "rb") as f,
                ):
                    start_monitoring_torch_compile(self.vllm_config)
                    loaded_fn = torch.compiler.load_compiled_function(
                        f, f_globals=self.forward.__globals__
                    )
                    # Reject the artifact if any traced source file changed.
                    _verify_source_unchanged(loaded_fn.source_info(), self.vllm_config)
                    if not self.compilation_config.dynamic_shapes_config.evaluate_guards:
                        loaded_fn.disable_guard_check()
                    self.aot_compiled_fn = loaded_fn
                    self.was_aot_compile_fn_loaded_from_disk = True
            except Exception as e:
                # A missing file is the normal cold-cache case; only warn
                # when an artifact existed but could not be loaded.
                if os.path.exists(aot_compilation_path):
                    if isinstance(e, EOFError):
                        message = "Compile cache file corrupted."
                    else:
                        message = str(e)
                    logger.warning(
                        "Compiling model again due to a load failure from %s, "
                        "reason: %s",
                        aot_compilation_path,
                        message,
                    )
                if envs.VLLM_FORCE_AOT_LOAD:
                    raise e
            if getattr(self, "aot_compiled_fn", None) is not None:
                logger.info(
                    "Directly load AOT compilation from path %s", aot_compilation_path
                )
                # Apply partition wrapper context for proper CUDA graph capture
                with maybe_use_cudagraph_partition_wrapper(self.vllm_config):
                    return self.aot_compiled_fn(self, *args, **kwargs)

        if self.compiled:
            assert (
                not envs.VLLM_USE_AOT_COMPILE
                or self.vllm_config.compilation_config.backend == "eager"
            )
            return TorchCompileWithNoGuardsWrapper.__call__(self, *args, **kwargs)  # type: ignore[arg-type]

        # This is the path for the first compilation.
        # the first compilation needs to have dynamic shapes marked
        _mark_dynamic_inputs(
            self,
            ds_type,
            *args,
            **kwargs,
        )

        # here, it is the starting point of the `torch.compile` process
        start_monitoring_torch_compile(self.vllm_config)
        original_code_object = self.original_code_object()
        logger.debug("Start compiling function %s", original_code_object)

        # we do not want to delete the original code object entries since
        # we depend on them now to look up cached compiled functions.
        # torch._dynamo.eval_frame.remove_from_cache(original_code_object)

        # collect all relevant files traced by Dynamo,
        # so that the compilation cache can trigger re-compilation
        # properly when any of these files change.

        # 1. the file containing the top-level forward function
        self.compilation_config.traced_files.add(original_code_object.co_filename)

        # 2. every time Dynamo sees a function call, it will inline
        # the function by calling InliningInstructionTranslator.inline_call_
        # we hijack this function to know all the functions called
        # during Dynamo tracing, and their corresponding files
        inline_call = InliningInstructionTranslator.inline_call_

        def patched_inline_call(self_: Any) -> Any:
            code = self_.f_code
            self.compilation_config.traced_files.add(code.co_filename)
            return inline_call(self_)

        # Disable the C++ compilation of symbolic shape guards. C++-fication
        # of symbolic shape guards can improve guard overhead. But, since
        # vllm skip guards anyways, setting this flag to False can improve
        # compile time.
        dynamo_config_patches = {}
        try:
            _ = torch._dynamo.config.enable_cpp_symbolic_shape_guards
            dynamo_config_patches["enable_cpp_symbolic_shape_guards"] = False
        except AttributeError:
            # Note: this config is not available in torch 2.6, we can skip
            # if the config doesn't exist
            logger.debug("enable_cpp_symbolic_shape_guards config not available")

        # Prepare backed_size_oblivious config patch if needed
        fx_config_patches = {}
        if ds_type == DynamicShapesType.BACKED_SIZE_OBLIVIOUS:
            fx_config_patches["backed_size_oblivious"] = True

        # Prepare inductor config patches
        # assume_32bit_indexing is only available in torch 2.10.0+
        inductor_config_patches = {}
        if is_torch_equal_or_newer("2.10.0"):
            inductor_config_patches["assume_32bit_indexing"] = (
                self.compilation_config.dynamic_shapes_config.assume_32_bit_indexing
            )

        with (
            patch.object(
                InliningInstructionTranslator, "inline_call_", patched_inline_call
            ),
            torch._dynamo.config.patch(**dynamo_config_patches),
            maybe_use_cudagraph_partition_wrapper(self.vllm_config),
            torch.fx.experimental._config.patch(**fx_config_patches),
            torch._inductor.config.patch(**inductor_config_patches),
        ):
            use_aot_compile = envs.VLLM_USE_AOT_COMPILE
            if self.vllm_config.compilation_config.backend == "eager":
                logger.warning("Detected eager backend, disabling AOT compile.")
                use_aot_compile = False
            if use_aot_compile:
                from vllm.compilation.backends import set_on_compilation_complete

                # store the path for saving after warmup
                self._aot_compilation_path = aot_compilation_path
                self._aot_cache_dir = cache_dir
                # set callback in context so it's available when compilation completes
                with set_on_compilation_complete(self.save_aot_compiled_function):
                    self.aot_compiled_fn = self.aot_compile(*args, **kwargs)
                    output = self.aot_compiled_fn(self, *args, **kwargs)
            else:
                output = TorchCompileWithNoGuardsWrapper.__call__(self, *args, **kwargs)  # type: ignore[arg-type]

        self.compiled = True
        return output

    # triggers VllmSerializableFunction.serialize()
    def save_aot_compiled_function(self: _T) -> None:
        if self.was_aot_compile_fn_loaded_from_disk:
            logger.debug("AOT compiled function was loaded from cache, skipping save")
            return

        assert (
            self.aot_compiled_fn and self._aot_compilation_path and self._aot_cache_dir
        )

        logger.info("saving AOT compiled function to %s", self._aot_compilation_path)
        try:
            os.makedirs(self._aot_cache_dir, exist_ok=True)
            # File saving should be atomic, so we will save to a temporary location
            # first. Should be upstreamed to PyTorch 2.12 as well.
            tmp_file = f"{self._aot_compilation_path}.{os.getpid()}.tmp"
            self.aot_compiled_fn.save_compiled_function(tmp_file)
            os.replace(tmp_file, self._aot_compilation_path)
            logger.info("saved AOT compiled function to %s", self._aot_compilation_path)
        except Exception as e:
            # Best-effort save: a failed write must not take down serving.
            logger.warning(
                "unable to save AOT compiled function to %s: %s",
                self._aot_compilation_path,
                e,
            )

    cls.__call__ = __call__
    cls.save_aot_compiled_function = save_aot_compiled_function
    return cls
|
||||
|
||||
|
||||
@contextlib.contextmanager
def maybe_use_cudagraph_partition_wrapper(
    vllm_config: VllmConfig,
) -> Generator[None, None, None]:
    """
    Context manager to set/unset customized cudagraph partition wrappers.

    If we're using Inductor-based graph partitioning, we currently have the
    whole `fx.Graph` before Inductor lowering and the piecewise
    splitting happens after all graph passes and fusions. Here, we add
    a custom hook for Inductor to wrap each partition with our static
    graph wrapper class to maintain more control over static graph
    capture and replay.

    The hook is always removed on exit (even if the managed body raises),
    so a stale wrapper never leaks into later compilations.
    """
    from vllm.config import CUDAGraphMode

    compilation_config = vllm_config.compilation_config
    # Evaluate the condition exactly once so that setup and teardown always
    # agree, even if the config object is mutated while the context is active.
    # (Previously the condition was re-checked after `yield`, and the reset
    # was skipped entirely when the body raised.)
    use_custom_wrappers = (
        compilation_config.cudagraph_mode.has_piecewise_cudagraphs()
        and compilation_config.use_inductor_graph_partition
    )
    if use_custom_wrappers:
        from torch._inductor.utils import CUDAGraphWrapperMetadata

        from vllm.compilation.cuda_graph import CUDAGraphOptions
        from vllm.platforms import current_platform

        static_graph_wrapper_class = resolve_obj_by_qualname(
            current_platform.get_static_graph_wrapper_cls()
        )

        def customized_cudagraph_wrapper(
            f: Callable[..., Any], metadata: CUDAGraphWrapperMetadata
        ) -> Any:
            # Wrap each Inductor partition in the platform's static graph
            # wrapper so capture/replay stays under vLLM's control.
            partition_id = metadata.partition_index
            num_partitions = metadata.num_partitions
            return static_graph_wrapper_class(
                runnable=f,
                vllm_config=vllm_config,
                runtime_mode=CUDAGraphMode.PIECEWISE,
                cudagraph_options=CUDAGraphOptions(
                    # Log from the first partition only, skip GC handling for
                    # all but the first, and weak-ref only the last
                    # partition's output.
                    debug_log_enable=partition_id == 0,
                    gc_disable=partition_id != 0,
                    weak_ref_output=partition_id == num_partitions - 1,
                ),
            )

        torch._inductor.utils.set_customized_partition_wrappers(
            customized_cudagraph_wrapper
        )

    try:
        yield
    finally:
        if use_custom_wrappers:
            torch._inductor.utils.set_customized_partition_wrappers(None)
|
||||
63
vllm/compilation/monitor.py
Normal file
63
vllm/compilation/monitor.py
Normal file
@@ -0,0 +1,63 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import time
|
||||
|
||||
from vllm.config import CompilationConfig, CompilationMode, VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
context_manager = None
|
||||
torch_compile_start_time: float = 0.0
|
||||
|
||||
|
||||
def start_monitoring_torch_compile(vllm_config: VllmConfig) -> None:
    """Record the torch.compile start time and, in VLLM_COMPILE mode with a
    debug dump path configured, open a depyf debugging context that dumps
    compilation artifacts under that path.

    The matching teardown/reporting happens in
    ``end_monitoring_torch_compile``.
    """
    global torch_compile_start_time
    torch_compile_start_time = time.perf_counter()

    compilation_config: CompilationConfig = vllm_config.compilation_config
    path = vllm_config.compile_debug_dump_path()
    if compilation_config.mode == CompilationMode.VLLM_COMPILE and path:
        # Imported lazily: depyf is only needed when debug dumping is on.
        import depyf

        path.mkdir(parents=True, exist_ok=True)
        logger.debug("Dumping depyf output to %s", path)
        global context_manager
        # The context is entered manually here and exited in
        # end_monitoring_torch_compile, since the two calls bracket an
        # arbitrary compilation span rather than a single `with` block.
        context_manager = depyf.prepare_debug(path.as_posix())
        context_manager.__enter__()
|
||||
|
||||
|
||||
def end_monitoring_torch_compile(vllm_config: VllmConfig) -> None:
    """Report total torch.compile wall time (measured from the last
    ``start_monitoring_torch_compile`` call) and close any depyf debug
    context opened by it."""
    compilation_config: CompilationConfig = vllm_config.compilation_config
    total_compile_time: float = time.perf_counter() - torch_compile_start_time
    if compilation_config.mode == CompilationMode.VLLM_COMPILE:
        logger.info_once(
            "torch.compile takes %.2f s in total",
            total_compile_time,
            scope="local",
        )
    global context_manager
    if context_manager is not None:
        # Close the depyf debug context opened in
        # start_monitoring_torch_compile.
        context_manager.__exit__(None, None, None)
        context_manager = None
|
||||
|
||||
|
||||
# Module-level flag guarding whether cudagraph capture is currently legal.
cudagraph_capturing_enabled: bool = True


def validate_cudagraph_capturing_enabled() -> None:
    """Raise ``RuntimeError`` if cudagraph capturing is currently disallowed.

    Call this immediately before any cudagraph capture attempt; it is how
    illegal captures are detected at runtime.
    """
    if not cudagraph_capturing_enabled:
        raise RuntimeError(
            "CUDA graph capturing detected at an inappropriate "
            "time. This operation is currently disabled."
        )


def set_cudagraph_capturing_enabled(enabled: bool) -> None:
    """Globally enable or disable cudagraph capturing."""
    global cudagraph_capturing_enabled
    cudagraph_capturing_enabled = enabled
|
||||
75
vllm/compilation/partition_rules.py
Normal file
75
vllm/compilation/partition_rules.py
Normal file
@@ -0,0 +1,75 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import contextlib
|
||||
from collections.abc import Generator
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def should_split(node: torch.fx.Node, splitting_ops: list[str]) -> bool:
    """Decide whether ``node`` is a partition point for dynamo graph splitting.

    Dynamo graph targets can be arbitrary callables; only ``OpOverload`` and
    ``OpOverloadPacket`` targets are ever considered split points, and only
    when they match an entry in ``splitting_ops``.
    """
    if node.op != "call_function":
        return False

    op = node.target

    if isinstance(op, torch._ops.OpOverloadPacket):
        # Packets match by qualified name, e.g. "aten::add".
        return op._qualified_op_name in splitting_ops

    if not isinstance(op, torch._ops.OpOverload):
        # Plain Python callables, builtins, etc. are never split points.
        return False

    # An overload matches either via its packet name (e.g. "aten::add") or
    # its fully qualified overload name (e.g. "aten::add.default").
    base_name = op.name()
    overload_name = f"{base_name}.{op._overloadname}"
    return overload_name in splitting_ops or base_name in splitting_ops
|
||||
|
||||
|
||||
@contextlib.contextmanager
def inductor_partition_rule_context(
    splitting_ops: list[str] | None,
) -> Generator[None, None, None]:
    """Temporarily install Inductor partition rules for ``splitting_ops``.

    While the context is active, the Inductor scheduler is forced to
    partition the graph at the given operators; the previous configuration
    is restored on exit.

    Args:
        splitting_ops: Operator names to partition on. ``None`` or an empty
            list makes the context a no-op.
    """
    if not splitting_ops:
        logger.debug("No partition ops provided; skipping rule registration.")
        yield
        return

    # Snapshot the current rules so they can be restored afterwards.
    previous_ops: list[str] = list(
        torch._inductor.config.custom_should_partition_ops
    )
    torch._inductor.config.custom_should_partition_ops = splitting_ops

    logger.debug(
        "Registered inductor partition rules for %d operators", len(splitting_ops)
    )

    try:
        yield
    finally:
        torch._inductor.config.custom_should_partition_ops = previous_ops
        logger.debug("Restored previous partition rules state.")
|
||||
0
vllm/compilation/passes/__init__.py
Normal file
0
vllm/compilation/passes/__init__.py
Normal file
0
vllm/compilation/passes/fusion/__init__.py
Normal file
0
vllm/compilation/passes/fusion/__init__.py
Normal file
215
vllm/compilation/passes/fusion/act_quant_fusion.py
Normal file
215
vllm/compilation/passes/fusion/act_quant_fusion.py
Normal file
@@ -0,0 +1,215 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||
from torch._inductor.pattern_matcher import (
|
||||
PatternMatcherPass,
|
||||
fwd_only,
|
||||
register_replacement,
|
||||
)
|
||||
from torch._ops import OpOverload
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
kFp8StaticTensorSym,
|
||||
kNvfp4Dynamic,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ..inductor_pass import enable_fake_mode
|
||||
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
from .matcher_utils import MatcherQuantFP8, MatcherSiluAndMul
|
||||
from .rms_quant_fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Platform-specific FP8 dtype (e.g. e4m3fn/e4m3fnuz depending on hardware).
FP8_DTYPE = current_platform.fp8_dtype()
# NVFP4 values are packed into uint8 storage.
FP4_DTYPE = torch.uint8

SILU_MUL_OP = torch.ops._C.silu_and_mul.default

# Maps a quantization scheme to the fused silu_and_mul+quant custom op that
# replaces the separate activation and quantization calls.
FUSED_OPS: dict[QuantKey, OpOverload] = {
    kFp8StaticTensorSym: torch.ops._C.silu_and_mul_quant.default,  # noqa: E501
}
# The NVFP4 fused op is only registered on CUDA builds that ship it.
silu_and_mul_nvfp4_quant_supported = current_platform.is_cuda() and hasattr(
    torch.ops._C, "silu_and_mul_nvfp4_quant"
)
if silu_and_mul_nvfp4_quant_supported:
    FUSED_OPS[kNvfp4Dynamic] = torch.ops._C.silu_and_mul_nvfp4_quant.default  # noqa: E501
|
||||
|
||||
|
||||
class ActivationQuantPattern(ABC):
    """
    The base class for Activation+Quant fusions.
    Should not be used directly.

    Subclasses resolve the quant and fused ops for their ``QuantKey`` and
    implement ``register`` to install the pattern into a
    ``PatternMatcherPass``.
    """

    def __init__(
        self,
        quant_key: QuantKey,
    ) -> None:
        self.quant_key = quant_key
        self.quant_dtype = quant_key.dtype

        # Both the standalone quant op and the fused activation+quant op must
        # exist for this scheme, otherwise the fusion cannot be expressed.
        assert self.quant_key in QUANT_OPS, (
            f"unsupported quantization scheme {self.quant_key}"
        )
        self.QUANT_OP = QUANT_OPS[self.quant_key]

        assert self.quant_key in FUSED_OPS, (
            f"unsupported fusion scheme {self.quant_key}"
        )
        self.FUSED_OP = FUSED_OPS[self.quant_key]

        self.silu_and_mul_matcher = MatcherSiluAndMul()

    def empty_quant(self, *args: Any, **kwargs: Any) -> torch.Tensor:
        # Convenience: torch.empty pre-bound to the scheme's quant dtype on
        # CUDA; explicit kwargs still override the defaults.
        kwargs = {"dtype": self.quant_dtype, "device": "cuda", **kwargs}
        return torch.empty(*args, **kwargs)

    @abstractmethod
    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register this fusion's pattern/replacement pair on *pm_pass*."""
        raise NotImplementedError
|
||||
|
||||
|
||||
class SiluMulFp8StaticQuantPattern(ActivationQuantPattern):
    """
    Fusion for SiluMul+Fp8StaticQuant Pattern
    """

    def __init__(self) -> None:
        super().__init__(kFp8StaticTensorSym)
        self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)

    def get_inputs(self) -> list[torch.Tensor]:
        # Example tensors used only to trace the pattern graph.
        scale = self.quant_matcher.inputs()[1]
        return [
            *self.silu_and_mul_matcher.inputs(),  # input
            scale,
        ]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        def pattern(
            input: torch.Tensor,
            scale: torch.Tensor,
        ) -> torch.Tensor:
            # Unfused form: silu_and_mul followed by a static FP8 quant.
            result_silu_mul = self.silu_and_mul_matcher(input)
            result_quant = self.quant_matcher(result_silu_mul, scale)
            return result_quant[0]

        def replacement(
            input: torch.Tensor,
            scale: torch.Tensor,
        ) -> torch.Tensor:
            # silu_and_mul halves the last dimension (gate/up halves).
            d = input.shape[-1] // 2
            output_shape = input.shape[:-1] + (d,)
            result = torch.empty(
                output_shape, device=input.device, dtype=self.quant_dtype
            )
            at = auto_functionalized(
                self.FUSED_OP, result=result, input=input, scale=scale
            )
            # Index 1 of the auto_functionalized tuple is the mutated
            # `result` tensor.
            return at[1]

        inps = self.get_inputs()
        # NOTE(review): the pattern is traced once eagerly before
        # registration — presumably to warm up/validate the matchers;
        # confirm the intent with the pattern-matching utilities.
        pattern(*inps)

        register_replacement(pattern, replacement, inps, fwd_only, pm_pass)
class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
    """
    Fusion for SiluMul+Nvfp4Quant Pattern
    """

    def __init__(self) -> None:
        super().__init__(kNvfp4Dynamic)

    def get_inputs(self) -> list[torch.Tensor]:
        # Example tensors used only to trace the pattern graph; shapes are
        # minimal placeholders (NVFP4 output packs two values per byte, so
        # the quant result's last dim is half of silu_and_mul's output dim).
        result = self.empty_quant(5, 32)
        output_scale = empty_i32(128, 4)
        input_ = empty_bf16(5, 64)
        scale = empty_fp32(1, 1)
        return [result, output_scale, input_, scale]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        def pattern(
            result: torch.Tensor,
            output_scale: torch.Tensor,
            input: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # Unfused form: silu_and_mul followed by a dynamic NVFP4 quant.
            result_silu_mul = self.silu_and_mul_matcher(input)
            at = auto_functionalized(
                self.QUANT_OP,
                output=result,
                input=result_silu_mul,
                output_scale=output_scale,
                input_scale=scale,
                is_sf_swizzled_layout=True,
            )
            # at[1] = quantized output, at[2] = output block scales.
            return at[1], at[2]

        def replacement(
            result: torch.Tensor,
            output_scale: torch.Tensor,
            input: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            at = auto_functionalized(
                self.FUSED_OP,
                result=result,
                result_block_scale=output_scale,
                input=input,
                input_global_scale=scale,
            )
            # at[1] = fused quantized output, at[2] = output block scales.
            return at[1], at[2]

        register_replacement(pattern, replacement, self.get_inputs(), fwd_only, pm_pass)
class ActivationQuantFusionPass(VllmPatternMatcherPass):
    """
    This pass fuses a pre-defined set of custom ops into fused ops.
    It uses the torch pattern matcher to find the patterns and replace them.

    Because patterns can only be registered once, the pass is a singleton.
    This will be addressed in a future version of PyTorch:
    https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
    """

    @enable_fake_mode
    def __init__(self, config: VllmConfig) -> None:
        super().__init__(config)

        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="activation_quant_fusion_pass"
        )

        # SiluMul + static FP8 quant is registered unconditionally.
        pattern_silu_mul_fp8 = SiluMulFp8StaticQuantPattern()
        pattern_silu_mul_fp8.register(self.patterns)

        # NVFP4 variant only when the fused kernel is compiled in.
        if silu_and_mul_nvfp4_quant_supported:
            pattern_silu_mul_nvfp4 = SiluMulNvfp4QuantPattern()
            pattern_silu_mul_nvfp4.register(self.patterns)

        self.dump_patterns(config, self.patterns)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: torch.fx.Graph) -> None:
        """Apply all registered fusions to `graph` in place."""
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)

    def uuid(self) -> str:
        # Cache key: derived from the source of this pass and its pattern
        # classes so cached compiled graphs invalidate when any change.
        return VllmInductorPass.hash_source(
            self,
            ActivationQuantPattern,
            SiluMulFp8StaticQuantPattern,
            SiluMulNvfp4QuantPattern,
        )
862
vllm/compilation/passes/fusion/allreduce_rms_fusion.py
Normal file
862
vllm/compilation/passes/fusion/allreduce_rms_fusion.py
Normal file
@@ -0,0 +1,862 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
from importlib.util import find_spec
from types import ModuleType

import torch
import torch._inductor.pattern_matcher as pm
import torch.fx as fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import PatternMatcherPass

from vllm.config import VllmConfig
from vllm.config.utils import Range
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    kFp8StaticTensorSym,
)
from vllm.platforms import current_platform
from vllm.utils.torch_utils import (
    direct_register_custom_op,
)

from ..inductor_pass import enable_fake_mode
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm

FP8_DTYPE = current_platform.fp8_dtype()

logger = init_logger(__name__)

# Optional dependency: only use flashinfer.comm when it exposes the fused
# allreduce entry points this pass relies on.
flashinfer_comm: ModuleType | None = None
if find_spec("flashinfer"):
    try:
        import flashinfer.comm as _flashinfer_comm

        if hasattr(_flashinfer_comm, "allreduce_fusion") and hasattr(
            _flashinfer_comm, "create_allreduce_fusion_workspace"
        ):
            flashinfer_comm = _flashinfer_comm
    except ImportError:
        pass

# The FP4 quant op only exists on builds that compile the kernel in.
if hasattr(torch.ops._C, "scaled_fp4_quant"):
    STATIC_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant.default

# Max size of the input tensor per world size per device capability
# to use flashinfer fused allreduce
FI_ALLREDUCE_FUSION_MAX_SIZE_MB: dict[int, dict[int, float]] = {
    90: {
        2: 64,  # 64MB
        4: 2,  # 2MB
        8: 0.5,  # 0.5MB
    },
    100: {
        2: 64,  # 64MB
        4: 32,  # 32MB
        8: 1,  # 1MB
    },
}

# Max size of the input tensor per world size per device capability
# to use flashinfer one shot fused allreduce
# OneShot max size is at most 64MB / world size (FlashInfer restriction)
_FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB: dict[int, dict[int, float]] = {
    90: {
        2: 32,  # 32MB
        4: 2,  # 2MB
        8: 0.5,  # 0.5MB
    },
    100: {
        2: 32,  # 32MB
        4: 4,  # 4MB
        8: 1,  # 1MB
    },
}


if flashinfer_comm is not None:
    # Workspace helpers are only importable when flashinfer is present.
    from vllm.distributed.device_communicators.flashinfer_all_reduce import (
        destroy_fi_ar_workspace,
        get_fi_ar_quant_workspace,
        get_fi_ar_workspace,
        initialize_fi_ar_quant_workspace,
        initialize_fi_ar_workspace,
    )

    ar_fusion_patterns = flashinfer_comm.AllReduceFusionPattern

MiB = 1024 * 1024
def call_trtllm_fused_allreduce_norm(
    allreduce_in: torch.Tensor,
    residual: torch.Tensor,
    rms_gamma: torch.Tensor,
    rms_eps: float,
    world_size: int,
    launch_with_pdl: bool,
    fp32_acc: bool,
    max_token_num: int,
    pattern_code: int,
    norm_out: torch.Tensor | None = None,
    quant_out: torch.Tensor | None = None,
    scale_out: torch.Tensor | None = None,
    scale_factor: torch.Tensor | None = None,
) -> None:
    """Invoke flashinfer's fused allreduce+RMSNorm(+quant) kernel in place.

    Mutates ``allreduce_in`` and ``residual`` (and ``norm_out``,
    ``quant_out``, ``scale_out`` when provided); returns nothing.
    ``pattern_code`` selects the flashinfer fusion variant.
    """
    num_tokens, hidden_size = allreduce_in.shape
    element_size = allreduce_in.element_size()
    current_tensor_size = num_tokens * hidden_size * element_size
    # The workspace was sized for max_token_num tokens; a larger input
    # would overflow it, so fail loudly instead.
    max_tensor_size = max_token_num * hidden_size * element_size
    assert current_tensor_size <= max_tensor_size, (
        f"Current tensor size {current_tensor_size} is larger than "
        f"max token num {max_token_num} * hidden size {hidden_size} * "
        f"element size {element_size}"
    )
    curr_device = current_platform.get_device_capability()
    device_capability = curr_device.to_int() if curr_device is not None else None
    # Get one shot input size limit for the current world size
    # for the current device capability
    max_one_shot_size = _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB.get(
        device_capability,  # type: ignore[arg-type, unused-ignore]
        {},
    ).get(world_size, None)
    # Use one shot if no max size is specified
    use_oneshot = (
        max_one_shot_size is None or current_tensor_size <= max_one_shot_size * MiB
    )

    # Select workspace based on pattern: quant patterns use the
    # trtllm quant workspace, non-quant patterns use the primary workspace.
    if pattern_code in (
        ar_fusion_patterns.kARResidualRMSNormFP8Quant,
        ar_fusion_patterns.kARResidualRMSNormFP4Quant,
    ):
        workspace = get_fi_ar_quant_workspace()
    else:
        workspace = get_fi_ar_workspace()
    assert workspace is not None, (
        "Flashinfer workspace must be initialized when using flashinfer"
    )
    assert flashinfer_comm is not None
    if norm_out is None:
        # In-place variant: the norm result overwrites allreduce_in.
        norm_out = allreduce_in
        residual_out = residual
    else:
        # return residual_out as allreduce_out with zeroed residual_in
        # as flashinfer does not support rms_norm
        # and allreduce_out together
        residual_out = allreduce_in

    layout_code = None
    # layout_code only supported by trtllm backend
    if workspace.backend == "trtllm":
        # in vllm we only support swizzled layout
        layout_code = flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4

    flashinfer_comm.allreduce_fusion(
        input=allreduce_in,
        workspace=workspace,
        pattern=pattern_code,
        launch_with_pdl=launch_with_pdl,
        output=None,
        residual_out=residual_out,
        norm_out=norm_out,
        quant_out=quant_out,
        scale_out=scale_out,
        residual_in=residual,
        rms_gamma=rms_gamma,
        rms_eps=rms_eps,
        scale_factor=scale_factor,
        layout_code=layout_code,
        use_oneshot=use_oneshot,
        fp32_acc=fp32_acc,
    )
def call_trtllm_fused_allreduce_norm_fake(
|
||||
allreduce_in: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
rms_gamma: torch.Tensor,
|
||||
rms_eps: float,
|
||||
world_size: int,
|
||||
launch_with_pdl: bool,
|
||||
fp32_acc: bool,
|
||||
max_token_num: int,
|
||||
pattern_code: int,
|
||||
norm_out: torch.Tensor | None = None,
|
||||
quant_out: torch.Tensor | None = None,
|
||||
scale_out: torch.Tensor | None = None,
|
||||
scale_factor: torch.Tensor | None = None,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
# Register the fused kernel as a vLLM custom op so the pattern matcher can
# target it; mutates_args lists every tensor the real impl writes to, and
# the fake impl is the no-op used during tracing.
direct_register_custom_op(
    op_name="flashinfer_trtllm_fused_allreduce_norm",
    op_func=call_trtllm_fused_allreduce_norm,
    mutates_args=[
        "allreduce_in",
        "residual",
        "norm_out",
        "quant_out",
        "scale_out",
    ],
    fake_impl=call_trtllm_fused_allreduce_norm_fake,
)
# Convenience handle to the registered overload used by the patterns below.
flashinfer_trtllm_fused_allreduce_norm = (
    torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default
)
class FlashInferFusedAllReduceParams:
|
||||
"""Parameters for FlashInfer fused allreduce operations."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
world_size: int,
|
||||
max_token_num: int = 1024,
|
||||
) -> None:
|
||||
self.world_size = world_size
|
||||
self.launch_with_pdl = True
|
||||
self.fp32_acc = True
|
||||
self.max_token_num = max_token_num
|
||||
|
||||
def get_trtllm_fused_allreduce_kwargs(self) -> dict[str, bool | int]:
|
||||
return {
|
||||
"world_size": self.world_size,
|
||||
"launch_with_pdl": self.launch_with_pdl,
|
||||
"fp32_acc": self.fp32_acc,
|
||||
"max_token_num": self.max_token_num,
|
||||
}
|
||||
|
||||
|
||||
# TODO(luka): unify
class BasePattern:
    """Common state (dtype, device, TP group) shared by the allreduce
    fusion patterns below."""

    def __init__(self, dtype: torch.dtype, device: str | None) -> None:
        self.dtype = dtype
        self.device = device
        # Tensor-parallel communication group and its size.
        self.tp = get_tp_group()
        self.tp_size = get_tensor_model_parallel_world_size()
class AllReduceRMSNormPattern(BasePattern):
    """
    This pattern replaces the allreduce + rms norm (without residual)
    with fused flashinfer implementation.
    Applies to allreduce + rmsnorm before attn in the first Transformer block.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str | None,
        allreduce_params: FlashInferFusedAllReduceParams,
    ) -> None:
        super().__init__(dtype, device)
        self.epsilon = epsilon
        self.allreduce_params = allreduce_params
        self.rmsnorm_matcher = MatcherRMSNorm(epsilon)

    def get_inputs(self) -> list[torch.Tensor]:
        # Example tensors used only to trace the pattern graph.
        input, weight = self.rmsnorm_matcher.inputs()

        # input goes through allreduce first, always 16-bit
        return [input.to(self.dtype), weight]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        def pattern(
            input: torch.Tensor, weight: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # Unfused form: allreduce followed by RMSNorm.
            allreduce_output = tensor_model_parallel_all_reduce(input)
            rms = self.rmsnorm_matcher(allreduce_output, weight)

            return rms, allreduce_output

        def replacement(
            input: torch.Tensor, weight: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # No residual in this variant: feed a zero residual and let the
            # kernel write the normalized result into rms_result.
            residual = torch.zeros_like(input)
            rms_result = torch.empty_like(input)
            assert flashinfer_comm is not None, "FlashInfer must be enabled"
            allreduce = auto_functionalized(
                flashinfer_trtllm_fused_allreduce_norm,
                allreduce_in=input,
                residual=residual,
                norm_out=rms_result,
                quant_out=None,
                scale_out=None,
                rms_gamma=weight,
                rms_eps=self.epsilon,
                pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm,
                **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
            )
            # rms_result, allreduce_in
            return allreduce[3], allreduce[1]

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )
class AllReduceFusedAddRMSNormPattern(BasePattern):
    """
    This pattern replaces the allreduce + rms norm (with residual)
    with fused flashinfer implementation.
    Applies to o_proj + rmsnorm after attn and mlp + rmsnorm before attn.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str | None,
        allreduce_params: FlashInferFusedAllReduceParams,
    ) -> None:
        super().__init__(dtype, device)
        self.epsilon = epsilon
        self.allreduce_params = allreduce_params
        self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)

    def get_inputs(self) -> list[torch.Tensor]:
        # Example tensors used only to trace the pattern graph.
        input, residual, weight = self.rmsnorm_matcher.inputs()

        # input goes through allreduce first, always 16-bit
        return [residual, input.to(self.dtype), weight]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        def pattern(
            residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # Unfused form: allreduce then fused-add RMSNorm.
            allreduce_output = tensor_model_parallel_all_reduce(input)
            rms, residual = self.rmsnorm_matcher(allreduce_output, weight, residual)
            return rms, residual

        def replacement(
            residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor]:
            assert flashinfer_comm is not None, "FlashInfer must be enabled"
            allreduce = auto_functionalized(
                flashinfer_trtllm_fused_allreduce_norm,
                allreduce_in=input,
                residual=residual,
                norm_out=None,
                quant_out=None,
                scale_out=None,
                rms_gamma=weight,
                rms_eps=self.epsilon,
                pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm,
                **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
            )
            # allreduce_in, residual
            return allreduce[1], allreduce[2]

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )

        # Same pattern, but only return the output and not residual
        # (helpful for end of graph where residual is not used again)
        first_return_only = lambda fn: lambda a, b, c: fn(a, b, c)[0]

        pm.register_replacement(
            first_return_only(pattern),  # type: ignore[no-untyped-call]
            first_return_only(replacement),  # type: ignore[no-untyped-call]
            self.get_inputs(),
            pm.fwd_only,
            pm_pass,
        )
class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern):
    """
    This pattern replaces the allreduce + rms norm (without residual)
    + static fp8 quant with fused flashinfer implementation.
    Applies to allreduce + rmsnorm + quant before attn
    in the first Transformer block.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str | None,
        allreduce_params: FlashInferFusedAllReduceParams,
    ) -> None:
        super().__init__(dtype, device)
        self.epsilon = epsilon
        self.allreduce_params = allreduce_params
        self.quant_dtype = torch.float8_e4m3fn
        self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
        self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)

    def get_inputs(self) -> list[torch.Tensor]:
        # Example tensors used only to trace the pattern graph.
        input, weight = self.rmsnorm_matcher.inputs()
        _, scale = self.quant_matcher.inputs()

        # input goes through allreduce first, always 16-bit
        return [input.to(self.dtype), weight, scale]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # Unfused form: allreduce, RMSNorm, then static FP8 quant.
            all_reduce = tensor_model_parallel_all_reduce(input)
            rms = self.rmsnorm_matcher(all_reduce, weight)
            quant, _ = self.quant_matcher(rms, scale)
            return quant, all_reduce

        def replacement(
            input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # No residual in this variant: feed a zero residual.
            residual = torch.zeros_like(input)
            result_rms = torch.empty_like(input)
            result_quant = torch.empty_like(input, dtype=self.quant_dtype)
            assert flashinfer_comm is not None, "FlashInfer must be enabled"
            allreduce = auto_functionalized(
                flashinfer_trtllm_fused_allreduce_norm,
                allreduce_in=input,
                residual=residual,
                norm_out=result_rms,
                quant_out=result_quant,
                scale_out=None,
                rms_gamma=weight,
                rms_eps=self.epsilon,
                # We don't use norm_out afterwards
                pattern_code=(
                    flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP8Quant
                ),
                scale_factor=scale,
                **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
            )

            # quant_out, allreduce_output
            return allreduce[4], allreduce[1]

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )
class AllReduceFusedAddRMSNormStaticQuantFP8Pattern(BasePattern):
    """
    This pattern replaces the allreduce + rms norm (with residual)
    + static fp8 quant with fused flashinfer implementation.
    Applies to o_proj + rmsnorm after attn + quant and
    mlp + rmsnorm + quant before attn.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str | None,
        allreduce_params: FlashInferFusedAllReduceParams,
    ) -> None:
        super().__init__(dtype, device)
        self.epsilon = epsilon
        self.allreduce_params = allreduce_params
        self.quant_dtype = torch.float8_e4m3fn

        self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
        self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)

    def get_inputs(self) -> list[torch.Tensor]:
        # Example tensors used only to trace the pattern graph.
        input, residual, weight = self.rmsnorm_matcher.inputs()
        _, scale = self.quant_matcher.inputs()

        # input goes through allreduce first, always 16-bit
        return [residual, input.to(self.dtype), weight, scale]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        def pattern(
            residual: torch.Tensor,
            input: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # Unfused form: allreduce, fused-add RMSNorm, static FP8 quant.
            allreduce_output = tensor_model_parallel_all_reduce(input)
            rms, res = self.rmsnorm_matcher(allreduce_output, weight, residual)
            quant, _ = self.quant_matcher(rms, scale)

            return quant, res

        def replacement(
            residual: torch.Tensor,
            input: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            result_quant = torch.empty_like(input, dtype=self.quant_dtype)
            assert flashinfer_comm is not None, "FlashInfer must be enabled"
            allreduce = auto_functionalized(
                flashinfer_trtllm_fused_allreduce_norm,
                allreduce_in=input,
                residual=residual,
                norm_out=None,
                quant_out=result_quant,
                scale_out=None,
                rms_gamma=weight,
                rms_eps=self.epsilon,
                # We don't use norm_out afterwards
                pattern_code=(
                    flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP8Quant
                ),
                scale_factor=scale,
                **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
            )
            # quant_out, rms_norm_residual
            return allreduce[4], allreduce[2]

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )
class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
    """
    This pattern replaces the allreduce + rms norm (without residual)
    + static nvfp4 quant with fused flashinfer implementation.
    Applies to allreduce + rmsnorm + quant before attn
    in the first Transformer block.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str | None,
        allreduce_params: FlashInferFusedAllReduceParams,
    ) -> None:
        super().__init__(dtype, device)
        self.epsilon = epsilon
        self.allreduce_params = allreduce_params
        self.rmsnorm_matcher = MatcherRMSNorm(epsilon)

    def get_inputs(self) -> list[torch.Tensor]:
        # Example tensors used only to trace the pattern graph; shapes are
        # minimal placeholders (NVFP4 output packs two values per byte).
        input = torch.empty([1, 16, 16], device=self.device, dtype=self.dtype)
        quant_result = torch.empty((16, 8), device=self.device, dtype=torch.uint8)
        input_global_scale = torch.empty(
            [1, 1], device=self.device, dtype=torch.float32
        )
        weight = torch.empty([16], device=self.device, dtype=self.dtype)
        output_scale = torch.empty([128, 4], device=self.device, dtype=torch.int32)

        return [input, quant_result, weight, input_global_scale, output_scale]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        def pattern(
            input: torch.Tensor,
            quant_result: torch.Tensor,
            weight: torch.Tensor,
            input_global_scale: torch.Tensor,
            output_scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            # Unfused form: allreduce, RMSNorm, then static NVFP4 quant.
            all_reduce = tensor_model_parallel_all_reduce(input)
            rms = self.rmsnorm_matcher(all_reduce, weight)
            quant_out_tuple = auto_functionalized(
                STATIC_FP4_QUANT_OP,
                output=quant_result,
                input=rms,
                output_scale=output_scale,
                input_scale=input_global_scale,
                is_sf_swizzled_layout=True,
            )

            # quant_out, allreduce_output, output_scale
            return quant_out_tuple[1], all_reduce, quant_out_tuple[2]

        def replacement(
            input: torch.Tensor,
            quant_result: torch.Tensor,
            weight: torch.Tensor,
            input_global_scale: torch.Tensor,
            output_scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            # No residual in this variant: feed a zero residual.
            residual = torch.zeros_like(input)
            result_rms = torch.empty_like(input)
            assert flashinfer_comm is not None, "FlashInfer must be enabled"
            allreduce = auto_functionalized(
                flashinfer_trtllm_fused_allreduce_norm,
                allreduce_in=input,
                residual=residual,
                norm_out=result_rms,
                quant_out=quant_result,
                scale_out=output_scale,
                rms_gamma=weight,
                rms_eps=self.epsilon,
                # We don't use norm_out afterwards
                pattern_code=(
                    flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP4Quant
                ),
                scale_factor=input_global_scale,
                **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
            )

            # quant_out, allreduce_output, output_scale
            return allreduce[4], allreduce[1], allreduce[5]

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )
class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
    """
    This pattern replaces the allreduce + rms norm (with residual)
    + static nvfp4 quant with fused flashinfer implementation.
    Applies to o_proj + rmsnorm after attn + quant and
    mlp + rmsnorm + quant before attn.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str | None,
        allreduce_params: FlashInferFusedAllReduceParams,
    ) -> None:
        super().__init__(dtype, device)
        self.epsilon = epsilon
        self.allreduce_params = allreduce_params
        self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)

    def get_inputs(self) -> list[torch.Tensor]:
        # Example tensors used only to trace the pattern graph; shapes are
        # minimal placeholders (NVFP4 output packs two values per byte).
        input = torch.empty([16, 16], device=self.device, dtype=self.dtype)

        residual = torch.empty([16, 16], device=self.device, dtype=self.dtype)
        weight = torch.empty([16, 16], device=self.device, dtype=self.dtype)
        quant_result = torch.empty((16, 8), device=self.device, dtype=torch.uint8)
        input_global_scale = torch.empty(
            [1, 1], device=self.device, dtype=torch.float32
        )
        output_scale = torch.empty([128, 4], device=self.device, dtype=torch.int32)

        return [
            quant_result,
            residual,
            input,
            output_scale,
            weight,
            input_global_scale,
        ]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        def pattern(
            quant_result: torch.Tensor,
            residual: torch.Tensor,
            input: torch.Tensor,
            output_scale: torch.Tensor,
            weight: torch.Tensor,
            input_global_scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            # Unfused form: allreduce, fused-add RMSNorm, static NVFP4 quant.
            allreduce_output = tensor_model_parallel_all_reduce(input)
            rms, residual = self.rmsnorm_matcher(allreduce_output, weight, residual)
            quant_out_tuple = auto_functionalized(
                STATIC_FP4_QUANT_OP,
                output=quant_result,
                input=rms,
                output_scale=output_scale,
                input_scale=input_global_scale,
                is_sf_swizzled_layout=True,
            )

            # quant_out, allreduce_output, output_scale
            return quant_out_tuple[1], residual, quant_out_tuple[2]

        def replacement(
            quant_result: torch.Tensor,
            residual: torch.Tensor,
            input: torch.Tensor,
            output_scale: torch.Tensor,
            weight: torch.Tensor,
            input_global_scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            assert flashinfer_comm is not None, "FlashInfer must be enabled"
            allreduce = auto_functionalized(
                flashinfer_trtllm_fused_allreduce_norm,
                allreduce_in=input,
                residual=residual,
                norm_out=None,
                quant_out=quant_result,
                scale_out=output_scale,
                rms_gamma=weight,
                rms_eps=self.epsilon,
                # We don't use norm_out afterwards
                pattern_code=(
                    flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP4Quant
                ),
                scale_factor=input_global_scale,
                **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
            )
            # quant_out, rms_norm_residual, output_scale
            return allreduce[4], allreduce[2], allreduce[5]

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )
class AllReduceFusionPass(VllmPatternMatcherPass):
    """Fuses tensor-parallel allreduce + RMSNorm (+ optional FP8/NVFP4
    quant) into flashinfer's fused allreduce kernels.

    The pass stays disabled unless TP > 1, flashinfer is available, and the
    flashinfer workspaces initialize successfully.
    """

    def __init__(self, config: VllmConfig) -> None:
        super().__init__(config)
        # Start disabled; only flipped to enabled at the end of a fully
        # successful setup (see the early returns below).
        self.disabled = True
        self.tp_size = get_tensor_model_parallel_world_size()
        if self.tp_size <= 1:
            logger.warning_once("AllReduce fusion pass is disabled for tp_size <= 1.")
            return
        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="all_reduce_fusion_pass"
        )
        if config.model_config is None:
            logger.warning_once(
                "AllReduce fusion pass is disabled for missing model_config."
            )
            return
        self.hidden_dim = config.model_config.get_hidden_size()
        self.group = get_tp_group().device_group
        rank = get_tensor_model_parallel_rank()
        if flashinfer_comm is None:
            logger.warning(
                "Flashinfer is not installed or comm module not found, "
                "skipping allreduce fusion pass"
            )
            return
        # Byte budget for fused allreduce at this world size (None means
        # flashinfer does not support this world size).
        max_size = config.compilation_config.pass_config.flashinfer_max_size(
            self.tp_size
        )
        if max_size is None:
            # Flashinfer doesn't support current world size
            logger.warning(
                "Flashinfer allreduce fusion is not supported for world size %s"
                " or max size is not provided",
                self.tp_size,
            )
            return
        # NOTE(review): self.model_dtype / self.device are read but not set
        # here — presumably populated by VllmPatternMatcherPass; confirm.
        element_size = torch.tensor([], dtype=self.model_dtype).element_size()
        # Convert the byte budget into a token budget for workspace sizing.
        self.max_token_num = max_size // (self.hidden_dim * element_size)
        # take the min to save workspace size and we'll never use more
        # than max_num_batched_tokens anyways
        self.max_token_num = min(
            self.max_token_num, config.scheduler_config.max_num_batched_tokens
        )
        logger.debug_once(
            f"Flashinfer max size: {max_size // (1024 * 1024)} MB,"
            "Maximal number of tokens used by "
            f"Flashinfer Allreduce Fusion: {self.max_token_num}",
            scope="global",
        )

        # Initialize both the plain and the quant workspaces; any failure
        # disables the pass rather than crashing startup.
        for workspace_init_fn in [
            initialize_fi_ar_workspace,
            initialize_fi_ar_quant_workspace,
        ]:
            try:
                workspace_init_fn(
                    world_size=self.tp_size,
                    rank=rank,
                    max_token_num=self.max_token_num,
                    hidden_dim=self.hidden_dim,
                    dtype=self.model_dtype,
                    group=self.group,
                )
            except Exception as e:
                if "multicast" in str(e).lower():
                    logger.warning(
                        "AllReduce fusion pass is disabled: flashinfer workspace "
                        "creation failed: %s. This is expected on GPUs without "
                        "NVSwitch (e.g., NVLink bridge-only or PCIe topologies). "
                        "Falling back to non-fused allreduce.",
                        str(e),
                    )
                else:
                    logger.warning(
                        "Failed to initialize FlashInfer All Reduce workspace: %s. "
                        "AllReduce fusion pass will be disabled.",
                        e,
                    )
                return

        self.allreduce_params = FlashInferFusedAllReduceParams(
            world_size=self.tp_size,
            max_token_num=self.max_token_num,
        )

        self.register_patterns()
        self.dump_patterns(config, self.patterns)

    @enable_fake_mode
    def register_patterns(self) -> None:
        """Register all allreduce fusion patterns; marks the pass enabled."""
        # Quant patterns require the quant workspace to have initialized.
        supports_quantization = get_fi_ar_quant_workspace() is not None
        # Register each pattern for both epsilons commonly used by models.
        for epsilon in [1e-5, 1e-6]:
            if supports_quantization:
                AllReduceFusedRMSNormStaticQuantFP8Pattern(
                    epsilon,
                    self.model_dtype,
                    self.device,
                    self.allreduce_params,
                ).register(self.patterns)
                AllReduceFusedAddRMSNormStaticQuantFP8Pattern(
                    epsilon,
                    self.model_dtype,
                    self.device,
                    self.allreduce_params,
                ).register(self.patterns)
                # NVFP4 kernels require Blackwell (compute capability 10.0).
                if current_platform.has_device_capability(100):
                    AllReduceFusedRMSNormStaticQuantNVFP4Pattern(
                        epsilon,
                        self.model_dtype,
                        self.device,
                        self.allreduce_params,
                    ).register(self.patterns)
                    AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(
                        epsilon,
                        self.model_dtype,
                        self.device,
                        self.allreduce_params,
                    ).register(self.patterns)
            AllReduceRMSNormPattern(
                epsilon,
                self.model_dtype,
                self.device,
                self.allreduce_params,
            ).register(self.patterns)
            AllReduceFusedAddRMSNormPattern(
                epsilon,
                self.model_dtype,
                self.device,
                self.allreduce_params,
            ).register(self.patterns)

            # WARNING: This is a hack to clear the pattern matcher cache
            # and allow multiple values of epsilon.
            torch._inductor.pattern_matcher._seen_patterns.clear()

        self.disabled = False

    def is_applicable_for_range(self, compile_range: Range) -> bool:
        """Only apply to compile ranges that fit the allocated workspace."""
        if self.disabled:
            logger.warning_once("AllReduce fusion pass is disabled.")
            return False
        return bool(compile_range.end <= self.max_token_num)
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: fx.Graph) -> None:
|
||||
if self.disabled:
|
||||
logger.debug("AllReduceFusionPass disabled")
|
||||
return
|
||||
|
||||
self.matched_count = self.patterns.apply(graph)
|
||||
logger.debug("Replaced %s patterns", self.matched_count)
|
||||
|
||||
def __del__(self) -> None:
|
||||
if getattr(self, "disabled", True):
|
||||
return
|
||||
with contextlib.suppress(Exception):
|
||||
destroy_fi_ar_workspace()
|
||||
374
vllm/compilation/passes/fusion/attn_quant_fusion.py
Normal file
374
vllm/compilation/passes/fusion/attn_quant_fusion.py
Normal file
@@ -0,0 +1,374 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from typing import Any, ParamSpec
|
||||
|
||||
import torch
|
||||
import torch._inductor.pattern_matcher as pm
|
||||
from torch import fx
|
||||
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||
|
||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention import Attention
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
kNvfp4Dynamic,
|
||||
kStaticTensorScale,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.math_utils import round_up
|
||||
|
||||
from ..fx_utils import is_func
|
||||
from ..inductor_pass import enable_fake_mode
|
||||
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
from .matcher_utils import MatcherQuantFP8
|
||||
from .rms_quant_fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
|
||||
|
||||
logger = init_logger(__name__)
|
||||
P = ParamSpec("P")
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
FP4_DTYPE = torch.uint8
|
||||
|
||||
ATTN_OP = torch.ops.vllm.unified_attention_with_output.default
|
||||
RESHAPE_OP = torch.ops.aten.reshape.default
|
||||
|
||||
|
||||
class AttentionQuantPattern(ABC):
    """
    The base class for Attn+Quant fusions.
    Should not be used directly.

    Subclasses implement `_register` with a traced pattern/replacement
    pair; registration is gated on the attention implementation
    supporting fused output quantization for the given quant scheme.
    """

    def __init__(
        self,
        layer: Attention,
        quant_key: QuantKey,
        dtype: torch.dtype,
    ) -> None:
        # Patterns are specialized per attention layer because the layer
        # name appears as a string constant in the matched graph.
        self.layer = layer
        self.layer_name = layer.layer_name
        self.num_heads = layer.num_heads
        self.head_size = layer.head_size
        # Quantization scheme to be fused onto the attention output.
        self.quant_key = quant_key
        self.quant_dtype = quant_key.dtype
        self.dtype = dtype

        assert self.quant_key in QUANT_OPS, (
            f"unsupported quantization scheme {self.quant_key}"
        )
        # Custom quant op corresponding to the scheme (from QUANT_OPS).
        self.QUANT_OP = QUANT_OPS[self.quant_key]

    def empty(self, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Empty CUDA tensor in the model dtype (used as a tracing input)."""
        kwargs = {"dtype": self.dtype, "device": "cuda", **kwargs}
        return torch.empty(*args, **kwargs)

    def empty_quant(self, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Empty CUDA tensor in the quantized dtype (used as a tracing input)."""
        kwargs = {"dtype": self.quant_dtype, "device": "cuda", **kwargs}
        return torch.empty(*args, **kwargs)

    @staticmethod
    def wrap_trace_fn(
        trace_fn: Callable[P, fx.GraphModule],
        *process_fx_fns: Callable[[fx.GraphModule], None],
    ) -> Callable[P, fx.GraphModule]:
        """Wrap *trace_fn* so each post-processing fn runs on the traced module.

        The processing functions mutate the GraphModule in place (e.g.
        view->reshape rewriting) so the traced pattern matches the shape
        of post-grad graphs.
        """

        def wrapped(*args: P.args, **kwargs: P.kwargs) -> fx.GraphModule:
            gm = trace_fn(*args, **kwargs)
            for process_fx in process_fx_fns:
                process_fx(gm)

            return gm

        return wrapped

    @staticmethod
    def fx_view_to_reshape(gm: torch.fx.GraphModule) -> None:
        """Rewrite aten.view nodes to aten.reshape, as inductor's post-grad
        graphs contain reshape rather than view."""
        from torch._inductor.fx_passes.post_grad import view_to_reshape

        view_to_reshape(gm)

    @staticmethod
    def remove_noop_permutes(gm: torch.fx.GraphModule) -> None:
        """Erase permute nodes whose dims are the identity permutation."""
        for node in gm.graph.nodes:
            if not is_func(node, torch.ops.aten.permute.default):
                continue

            dims = node.args[1]
            if any(dim != i for i, dim in enumerate(dims)):
                continue

            # this is now an identity op, remove
            node.replace_all_uses_with(node.args[0])
            gm.graph.erase_node(node)

    def register_if_supported(self, pm_pass: PatternMatcherPass) -> None:
        """Register this pattern only if the attention impl can fuse the
        output quantization described by `self.quant_key`."""
        if self.layer.impl.fused_output_quant_supported(self.quant_key):
            self._register(pm_pass)

    @abstractmethod
    def _register(self, pm_pass: PatternMatcherPass) -> None:
        # Subclasses trace and register their pattern/replacement pair here.
        raise NotImplementedError
|
||||
|
||||
|
||||
class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
    """
    Fusion for Attention+Fp8StaticQuant.

    Only triggers when the attention implementation returns True in
    `fused_output_quant_supported()`. If the pattern is found, the
    Fp8StaticQuant op will be removed from the graph, and its scale
    will be passed into Attention op as the `output_scale` argument.
    """

    def __init__(
        self,
        layer: Attention,
        dtype: torch.dtype,
        symmetric: bool = True,
    ) -> None:
        quant_key = QuantKey(
            dtype=FP8_DTYPE, scale=kStaticTensorScale, symmetric=symmetric
        )
        super().__init__(layer, quant_key, dtype)
        # Matches the static FP8 quant op sequence in the traced graph.
        self.quant_matcher = MatcherQuantFP8(quant_key)

    def _register(self, pm_pass: PatternMatcherPass) -> None:
        # NOTE: `pattern` and `replacement` are traced by the inductor
        # pattern matcher; the exact op sequence they produce is what is
        # matched/emitted, so their bodies must mirror post-grad graphs.
        def pattern(
            q: torch.Tensor,
            k: torch.Tensor,
            v: torch.Tensor,
            output_attn: torch.Tensor,
            scale: torch.Tensor,
            kv_cache_dummy_dep: torch.Tensor,
        ) -> torch.Tensor:
            # Unfused form: attention without output_scale, followed by a
            # reshape and the standalone quant op.
            at1 = auto_functionalized(
                ATTN_OP,
                query=q,
                key=k,
                value=v,
                output=output_attn,
                layer_name=self.layer_name,
                output_scale=None,
                output_block_scale=None,
                kv_cache_dummy_dep=kv_cache_dummy_dep,
            )
            attn_out_view = RESHAPE_OP(
                at1[1], [q.shape[0], self.num_heads * self.head_size]
            )

            return self.quant_matcher(attn_out_view, scale)[0]

        def replacement(
            q: torch.Tensor,
            k: torch.Tensor,
            v: torch.Tensor,
            output_attn: torch.Tensor,
            scale: torch.Tensor,
            kv_cache_dummy_dep: torch.Tensor,
        ) -> torch.Tensor:
            # attn output in quant_dtype
            output_attn = torch.ops.aten.full.default(
                [q.shape[0], self.num_heads, self.head_size],
                0.0,
                dtype=self.quant_dtype,
                device=q.device,
            )
            # Fused form: the quant scale is passed straight into the
            # attention op via `output_scale`.
            at1 = auto_functionalized(
                ATTN_OP,
                query=q,
                key=k,
                value=v,
                output=output_attn,
                layer_name=self.layer_name,
                output_scale=scale,
                output_block_scale=None,
                kv_cache_dummy_dep=kv_cache_dummy_dep,
            )
            return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size])

        # Tracing inputs; sizes are placeholders that the matcher
        # generalizes over.
        inputs = [
            self.empty(5, self.num_heads, self.head_size),  # q
            self.empty(5, self.num_heads, self.head_size),  # k
            self.empty(5, self.num_heads, self.head_size),  # v
            self.empty(5, self.num_heads, self.head_size),  # attn_output
            empty_fp32(1, 1),  # scale
            self.empty(0),  # kv_cache_dummy_dep
        ]

        pm.register_replacement(
            pattern,
            replacement,
            inputs,
            AttentionQuantPattern.wrap_trace_fn(
                pm.fwd_only,
                AttentionQuantPattern.fx_view_to_reshape,
                AttentionQuantPattern.remove_noop_permutes,
            ),
            pm_pass,
        )
|
||||
|
||||
|
||||
class AttentionNvfp4QuantPattern(AttentionQuantPattern):
    """
    Fusion for Attention+Nvfp4Quant.

    Only triggers when the attention implementation returns True in
    `fused_output_quant_supported()`. If the pattern is found, the
    Nvfp4Quant op will be removed from the graph, and its scale
    will be passed into Attention op as the `output_scale` argument.
    """

    def __init__(self, layer: Attention, dtype: torch.dtype) -> None:
        super().__init__(layer, kNvfp4Dynamic, dtype)

    def _register(self, pm_pass: PatternMatcherPass) -> None:
        # NOTE: `pattern` and `replacement` are traced by the inductor
        # pattern matcher; their exact op sequences must mirror post-grad
        # graphs precisely.
        def pattern(
            q: torch.Tensor,
            k: torch.Tensor,
            v: torch.Tensor,
            output_attn: torch.Tensor,
            output_quant: torch.Tensor,
            output_scale: torch.Tensor,
            input_scale: torch.Tensor,
            kv_cache_dummy_dep: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # Unfused form: attention, reshape, then the standalone
            # NVFP4 quant op producing (quantized output, block scales).
            at1 = auto_functionalized(
                ATTN_OP,
                query=q,
                key=k,
                value=v,
                output=output_attn,
                layer_name=self.layer_name,
                output_scale=None,
                output_block_scale=None,
                kv_cache_dummy_dep=kv_cache_dummy_dep,
            )
            attn_out_view = RESHAPE_OP(
                at1[1], [q.shape[0], self.num_heads * self.head_size]
            )
            at2 = auto_functionalized(
                self.QUANT_OP,
                output=output_quant,
                input=attn_out_view,
                output_scale=output_scale,
                input_scale=input_scale,
                is_sf_swizzled_layout=True,
            )
            # Block scales are bit-reinterpreted (int32 storage) as FP8.
            output_scale_view = torch.ops.aten.view.dtype(at2[2], FP8_DTYPE)
            return at2[1], output_scale_view

        def replacement(
            q: torch.Tensor,
            k: torch.Tensor,
            v: torch.Tensor,
            output_attn: torch.Tensor,
            output_quant: torch.Tensor,
            output_scale: torch.Tensor,
            input_scale: torch.Tensor,
            kv_cache_dummy_dep: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # attention output in quant_dtype; NVFP4 packs two values per
            # byte, hence head_size // 2.
            output_attn = torch.ops.aten.full.default(
                [q.shape[0], self.num_heads, self.head_size // 2],
                0.0,
                dtype=self.quant_dtype,
                device=q.device,
            )
            # attention output block scale
            output_scale_view = torch.ops.aten.view.dtype(output_scale, FP8_DTYPE)
            # Fused form: the quant input scale becomes the attention
            # `output_scale`, and the block-scale buffer is handed in as
            # `output_block_scale`.
            at2 = auto_functionalized(
                ATTN_OP,
                query=q,
                key=k,
                value=v,
                output=output_attn,
                layer_name=self.layer_name,
                output_scale=input_scale,
                output_block_scale=output_scale_view,
                kv_cache_dummy_dep=kv_cache_dummy_dep,
            )
            output = RESHAPE_OP(at2[1], [-1, self.num_heads * self.head_size // 2])
            return output, at2[2]

        # Tracing inputs; sizes are placeholders the matcher generalizes.
        inputs = [
            empty_bf16(5, self.num_heads, self.head_size),  # q
            empty_bf16(5, self.num_heads, self.head_size),  # k
            empty_bf16(5, self.num_heads, self.head_size),  # v
            empty_bf16(5, self.num_heads, self.head_size),  # output_attn
            self.empty_quant(5, self.num_heads * self.head_size // 2),  # output_quant
            empty_i32(
                128, round_up(self.num_heads * self.head_size // 16, 4)
            ),  # output_scale
            empty_fp32(1, 1),  # input_scale
            self.empty(0),  # kv_cache_dummy_dep
        ]

        pm.register_replacement(
            pattern,
            replacement,
            inputs,
            AttentionQuantPattern.wrap_trace_fn(
                pm.fwd_only,
                AttentionQuantPattern.fx_view_to_reshape,
                AttentionQuantPattern.remove_noop_permutes,
            ),
            pm_pass,
        )
|
||||
|
||||
|
||||
class AttnFusionPass(VllmPatternMatcherPass):
    """
    This pass fuses post-attention quantization onto attention if supported.

    It uses the pattern matcher and matches each layer manually, as strings
    cannot be wildcarded. This also lets us check support on attention layers
    upon registration instead of during pattern matching.

    Currently, only static fp8 quant is supported, but patterns could easily be
    added for other quant schemes and dtypes. The bigger hurdle for wider
    support are attention kernels, which need to support fusing output quant.
    """

    @enable_fake_mode
    def __init__(self, config: VllmConfig) -> None:
        super().__init__(config)

        self.patterns = PatternMatcherPass(pass_name="attn_fusion_pass")

        # NVFP4 fusion needs the CUDA-only scaled_fp4_quant custom op; the
        # check is loop-invariant, so evaluate it once.
        nvfp4_supported = current_platform.is_cuda() and hasattr(
            torch.ops._C, "scaled_fp4_quant"
        )
        model_dtype = config.model_config.dtype

        attn_layers = get_layers_from_vllm_config(config, Attention)
        for layer in attn_layers.values():
            fp8_pattern = AttentionFp8StaticQuantPattern(layer, model_dtype)
            fp8_pattern.register_if_supported(self.patterns)

            if nvfp4_supported:
                nvfp4_pattern = AttentionNvfp4QuantPattern(layer, model_dtype)
                nvfp4_pattern.register_if_supported(self.patterns)

        if not attn_layers:
            logger.warning(
                "Attention + quant fusion is enabled, but no attention layers "
                "were found in CompilationConfig.static_forward_context "
                "so no fusion patterns were registered."
            )

        self.dump_patterns(config, self.patterns)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: torch.fx.graph.Graph) -> None:
        """Apply the registered attention+quant patterns to *graph*."""
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Fused quant onto %s attention nodes", self.matched_count)

    def uuid(self) -> str:
        """Content hash of this pass and its pattern classes, used for
        compilation caching."""
        hashed_classes = (
            AttentionQuantPattern,
            AttentionFp8StaticQuantPattern,
            AttentionNvfp4QuantPattern,
        )
        return VllmInductorPass.hash_source(self, *hashed_classes)
|
||||
423
vllm/compilation/passes/fusion/collective_fusion.py
Normal file
423
vllm/compilation/passes/fusion/collective_fusion.py
Normal file
@@ -0,0 +1,423 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
import torch._inductor.pattern_matcher as pm
|
||||
import torch.fx as fx
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||
from torch.distributed._symmetric_memory import enable_symm_mem_for_group
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.utils import Range
|
||||
from vllm.distributed import get_tp_group
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ..inductor_pass import enable_fake_mode
|
||||
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class BasePattern:
    """Common state shared by the async-TP fusion patterns below."""

    def __init__(self, dtype: torch.dtype, device: str | None) -> None:
        # Model dtype/device used when building tracing inputs.
        self.dtype = dtype
        self.device = device
        # Tensor-parallel group; the matched collectives run on this group.
        self.tp = get_tp_group()
        self.tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
|
||||
class GEMMReduceScatterPattern(BasePattern):
    """Fuse mm -> reduce_scatter into symm_mem.fused_matmul_reduce_scatter."""

    def get_inputs(self) -> list[torch.Tensor]:
        # Tracing inputs; the sizes are placeholders the matcher
        # generalizes over.
        mul = torch.empty([16, 4], device=self.device, dtype=self.dtype)
        mm_weight = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        return [mul, mm_weight]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        # NOTE: pattern/replacement are traced; their op sequences are
        # matched/emitted verbatim.
        def pattern(mul: torch.Tensor, mm_weight: torch.Tensor) -> torch.Tensor:
            mm = torch.ops.aten.mm.default(mul, mm_weight)
            reduce_scatter = torch.ops.vllm.reduce_scatter.default(
                mm,
                dim=0,
                world_size=self.tp_size,
                group_name=self.tp.unique_name,
            )
            return reduce_scatter

        def replacement(mul: torch.Tensor, mm_weight: torch.Tensor) -> torch.Tensor:
            gemm_rs = torch.ops.symm_mem.fused_matmul_reduce_scatter(
                mul,
                mm_weight,
                "avg",
                scatter_dim=0,
                group_name=self.tp.device_group.group_name,
            )

            return gemm_rs

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )
|
||||
|
||||
|
||||
class AllGatherGEMMPattern(BasePattern):
    """Fuse all_gather -> mm into symm_mem.fused_all_gather_matmul."""

    def get_inputs(self) -> list[torch.Tensor]:
        # Tracing inputs; placeholder sizes.
        x = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        weight = torch.empty([4, 4], device=self.device, dtype=self.dtype)

        return [x, weight]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        # NOTE: pattern/replacement are traced; op sequences are matched
        # verbatim.
        def pattern(
            x: torch.Tensor,
            weight: torch.Tensor,
        ) -> torch.Tensor:
            all_gather = torch.ops.vllm.all_gather.default(
                x,
                dim=0,
                world_size=self.tp_size,
                group_name=self.tp.unique_name,
            )

            return torch.ops.aten.mm.default(all_gather, weight)

        def replacement(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
            # fused_all_gather_matmul takes a list of weights and returns a
            # list of matmul outputs (one per weight).
            # NOTE(review): the single-element list is returned as-is rather
            # than unpacked (mm_outputs[0]) — presumably the pattern matcher
            # handles the list output; confirm against upstream behavior.
            ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_matmul(
                x,
                [weight],
                gather_dim=0,
                group_name=self.tp.device_group.group_name,
            )
            return mm_outputs

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )
|
||||
|
||||
|
||||
class ScaledMMReduceScatterPattern(BasePattern):
    """Fuse scaled_mm -> reduce_scatter into the patched
    fused_scaled_matmul_reduce_scatter op."""

    def get_inputs(self) -> list[torch.Tensor]:
        # FP8 input and weight; `.contiguous().transpose(0, 1)` yields the
        # column-major layout scaled_mm expects for the second operand.
        input = torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
        mm_weight = (
            torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
            .contiguous()
            .transpose(0, 1)
        )
        # Per-row scale for the input, per-column scale for the weight.
        scale_a = torch.empty([16, 1], device=self.device, dtype=torch.float32)
        scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)
        return [input, mm_weight, scale_a, scale_b]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        # NOTE: pattern/replacement are traced; op sequences are matched
        # verbatim.
        def pattern(
            input: torch.Tensor,
            mat2: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
        ) -> torch.Tensor:
            scaled_mm = torch.ops.aten._scaled_mm.default(
                input,
                mat2=mat2,
                scale_a=scale_a,
                scale_b=scale_b,
                bias=None,
                scale_result=None,
                out_dtype=self.dtype,
            )
            reduce_scatter = torch.ops.vllm.reduce_scatter.default(
                scaled_mm,
                dim=0,
                world_size=self.tp_size,
                group_name=self.tp.unique_name,
            )
            return reduce_scatter

        def replacement(
            input: torch.Tensor,
            mat2: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
        ) -> torch.Tensor:
            # Calculate output shape: input @ mat2 with scatter_dim reduced
            output_shape = [*input.shape[:-1], mat2.shape[1]]
            scatter_dim = 0
            gemm_rs = torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter(
                input,
                mat2,
                scale_a,
                scale_b,
                "avg",
                scatter_dim,  # orig_scatter_dim
                scatter_dim,  # scatter_dim_after_maybe_reshape
                self.tp.device_group.group_name,
                output_shape,
                None,  # bias
                None,  # result_scale
                self.dtype,  # out_dtype
                False,  # use_fast_accum
            )

            return gemm_rs

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )
|
||||
|
||||
|
||||
class AllGatherScaledMMPattern(BasePattern):
    """Fuse all_gather -> scaled_mm into
    symm_mem.fused_all_gather_scaled_matmul."""

    def get_inputs(self) -> list[torch.Tensor]:
        x = torch.empty([8, 16], device=self.device, dtype=FP8_DTYPE)
        # Column-major FP8 weight, as scaled_mm expects.
        weight = (
            torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
            .contiguous()
            .transpose(0, 1)
        )

        # scale_a spans the gathered (post-all_gather) first dimension.
        s1 = x.shape[0] * self.tp_size

        scale_a = torch.empty([s1, 1], device=self.device, dtype=torch.float32)
        scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)

        return [x, weight, scale_a, scale_b]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        # NOTE: pattern/replacement are traced; op sequences are matched
        # verbatim.
        def pattern(
            x: torch.Tensor,
            weight: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
        ) -> torch.Tensor:
            all_gather = torch.ops.vllm.all_gather.default(
                x, dim=0, world_size=self.tp_size, group_name=self.tp.unique_name
            )

            return torch.ops.aten._scaled_mm.default(
                all_gather,
                mat2=weight,
                scale_a=scale_a,
                scale_b=scale_b,
                bias=None,
                scale_result=None,
                out_dtype=self.dtype,
            )

        def replacement(
            x: torch.Tensor,
            weight: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
        ) -> torch.Tensor:
            # Single weight, so all list-typed arguments carry one element.
            ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_scaled_matmul(  # noqa
                x,
                [weight],
                scale_a,
                [scale_b],
                gather_dim=0,
                biases=[None],
                result_scales=[None],
                out_dtypes=[self.dtype],
                use_fast_accum=[False],
                group_name=self.tp.device_group.group_name,
            )
            return mm_outputs

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )
|
||||
|
||||
|
||||
class CutlassScaledMMReduceScatterPattern(BasePattern):
    """Fuse cutlass_scaled_mm -> reduce_scatter into the patched
    fused_scaled_matmul_reduce_scatter op."""

    def get_inputs(self) -> list[torch.Tensor]:
        input = torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
        # Column-major FP8 weight, as cutlass_scaled_mm expects.
        mm_weight = (
            torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
            .contiguous()
            .transpose(0, 1)
        )
        scale_a = torch.empty([16, 1], device=self.device, dtype=torch.float32)
        scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)

        # Pre-allocated output buffer for the out-variant cutlass op.
        cutlass_mm_output = torch.empty([16, 16], device=self.device, dtype=self.dtype)
        return [input, mm_weight, scale_a, scale_b, cutlass_mm_output]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        # NOTE: pattern/replacement are traced; op sequences are matched
        # verbatim. cutlass_scaled_mm is an out-variant custom op, hence
        # the auto_functionalized wrapper and the [1] result index.
        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
            cutlass_mm_output: torch.Tensor,
        ) -> torch.Tensor:
            cutlass_scaled_mm = torch.ops.higher_order.auto_functionalized(
                torch.ops._C.cutlass_scaled_mm.default,
                out=cutlass_mm_output,
                a=input,
                b=weight,
                a_scales=scale_a,
                b_scales=scale_b,
                bias=None,
            )

            reduce_scatter = torch.ops.vllm.reduce_scatter.default(
                cutlass_scaled_mm[1],
                dim=0,
                world_size=self.tp_size,
                group_name=self.tp.unique_name,
            )
            return reduce_scatter

        def replacement(
            input: torch.Tensor,
            mat2: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
            cutlass_mm_output: torch.Tensor,
        ) -> torch.Tensor:
            # Calculate output shape: input @ mat2 with scatter_dim reduced
            output_shape = [*input.shape[:-1], mat2.shape[1]]
            scatter_dim = 0
            gemm_rs = torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter(
                input,
                mat2,
                scale_a,
                scale_b,
                "avg",
                scatter_dim,  # orig_scatter_dim
                scatter_dim,  # scatter_dim_after_maybe_reshape
                self.tp.device_group.group_name,
                output_shape,
                None,  # bias
                None,  # result_scale
                self.dtype,  # out_dtype
                False,  # use_fast_accum
            )

            return gemm_rs

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )
|
||||
|
||||
|
||||
class AllGatherCutlassScaledMMPattern(BasePattern):
    """Fuse all_gather -> cutlass_scaled_mm into
    symm_mem.fused_all_gather_scaled_matmul."""

    def get_inputs(self) -> list[torch.Tensor]:
        x = torch.empty([8, 16], device=self.device, dtype=FP8_DTYPE)
        # Column-major FP8 weight, as cutlass_scaled_mm expects.
        weight = (
            torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
            .contiguous()
            .transpose(0, 1)
        )

        # First dim after the all_gather.
        s1 = x.shape[0] * self.tp_size

        scale_a = torch.empty([s1, 1], device=self.device, dtype=torch.float32)
        scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)

        # Pre-allocated output buffer for the out-variant cutlass op.
        s2 = weight.shape[1]
        output = torch.empty([s1, s2], device=self.device, dtype=self.dtype)

        return [x, weight, scale_a, scale_b, output]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        # NOTE: pattern/replacement are traced; op sequences are matched
        # verbatim.
        def pattern(
            x: torch.Tensor,
            weight: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
            output: torch.Tensor,
        ) -> torch.Tensor:
            all_gather = torch.ops.vllm.all_gather.default(
                x, dim=0, world_size=self.tp_size, group_name=self.tp.unique_name
            )

            cutlass_scaled_mm = torch.ops.higher_order.auto_functionalized(
                torch.ops._C.cutlass_scaled_mm.default,
                out=output,
                a=all_gather,
                b=weight,
                a_scales=scale_a,
                b_scales=scale_b,
                bias=None,
            )
            return cutlass_scaled_mm[1]

        def replacement(
            x: torch.Tensor,
            weight: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
            output: torch.Tensor,
        ) -> torch.Tensor:
            # Single weight, so all list-typed arguments carry one element.
            ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_scaled_matmul(  # noqa
                x,
                [weight],
                scale_a,
                [scale_b],
                gather_dim=0,
                biases=[None],
                result_scales=[None],
                out_dtypes=[self.dtype],
                use_fast_accum=[False],
                group_name=self.tp.device_group.group_name,
            )
            return mm_outputs

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )
|
||||
|
||||
|
||||
class AsyncTPPass(VllmPatternMatcherPass):
    """Async tensor-parallelism pass.

    Rewrites mm -> reduce_scatter and all_gather -> mm sequences (plain,
    scaled, and cutlass-scaled variants) into fused symmetric-memory
    kernels so communication overlaps with computation.
    """

    @enable_fake_mode
    def __init__(self, config: VllmConfig) -> None:
        super().__init__(config)

        # Enable symmetric memory for the TP process group
        enable_symm_mem_for_group(get_tp_group().device_group.group_name)
        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="async_tp_pass"
        )

        dtype = self.model_dtype
        device = self.device
        GEMMReduceScatterPattern(dtype, device).register(self.patterns)

        AllGatherGEMMPattern(dtype, device).register(self.patterns)

        # These fusions are enabled only for bfloat16 models because
        # `scaled_mm` or `cutlass_scaled_mm` with per-token (row-wise) scaling
        # only supports bfloat16 as the output dtype.
        if dtype == torch.bfloat16:
            scaled_pattern_classes = [
                ScaledMMReduceScatterPattern,
                AllGatherScaledMMPattern,
                CutlassScaledMMReduceScatterPattern,
                AllGatherCutlassScaledMMPattern,
            ]
            for pattern_cls in scaled_pattern_classes:
                pattern_cls(dtype, device).register(self.patterns)

        self.dump_patterns(config, self.patterns)

    def is_applicable_for_range(self, compile_range: Range) -> bool:
        """Return True if async-TP fusion may run for this compile range."""
        # This pass is applied on top of the sequence parallelism pass.
        # It inherits the same applicability condition as `SequenceParallelismPass`.
        # See `SequenceParallelismPass.is_applicable` for more details.
        compilation_config = self.compilation_config
        if not compilation_config.splitting_ops:
            return True
        if compilation_config.use_inductor_graph_partition:
            return True
        tp_size = get_tensor_model_parallel_world_size()
        evenly_sharded = compile_range.end % tp_size == 0
        return bool(compile_range.is_single_size() and evenly_sharded)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph) -> None:
        """Apply the registered async-TP patterns to *graph*."""
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)
|
||||
472
vllm/compilation/passes/fusion/matcher_utils.py
Normal file
472
vllm/compilation/passes/fusion/matcher_utils.py
Normal file
@@ -0,0 +1,472 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch._higher_order_ops import auto_functionalized
|
||||
from torch._ops import OpOverload
|
||||
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
QuantKey,
|
||||
_normalize_quant_group_shape,
|
||||
kFp8Dynamic64Sym,
|
||||
kFp8Dynamic128Sym,
|
||||
kFp8DynamicTensorSym,
|
||||
kFp8DynamicTokenSym,
|
||||
kFp8StaticTensorSym,
|
||||
kNvfp4Dynamic,
|
||||
)
|
||||
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
RMS_OP = torch.ops._C.rms_norm.default
|
||||
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
|
||||
ROTARY_OP = torch.ops._C.rotary_embedding.default
|
||||
FLASHINFER_ROTARY_OP = torch.ops.vllm.flashinfer_rotary_embedding.default
|
||||
|
||||
QUANT_OPS: dict[QuantKey, OpOverload] = {
|
||||
kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501
|
||||
kFp8DynamicTensorSym: torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501
|
||||
kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501
|
||||
}
|
||||
|
||||
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
|
||||
QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.default # noqa: E501
|
||||
|
||||
if current_platform.is_cuda():
|
||||
QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
|
||||
QUANT_OPS[kFp8Dynamic64Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
|
||||
|
||||
SILU_MUL_OP = torch.ops._C.silu_and_mul.default
|
||||
|
||||
|
||||
class MatcherCustomOp(ABC):
    """Base class for pattern-matcher helpers mirroring a vLLM custom op.

    Depending on whether the custom op is enabled in the current vLLM
    config, calls dispatch to either the custom or the native
    implementation, so traced patterns match what the model executes.
    """

    def __init__(self, enabled: bool) -> None:
        config = get_current_vllm_config()
        model_config = config.model_config
        device_config = config.device_config
        # Fall back to None when the global config lacks these sections.
        self.model_dtype = model_config.dtype if model_config else None
        self.device = device_config.device if device_config else None

        self.enabled = enabled
        # Bind the dispatch target once; __call__ simply forwards to it.
        if enabled:
            self.forward = self.forward_custom
        else:
            self.forward = self.forward_native

    @abstractmethod
    def forward_custom(self, *args: Any, **kwargs: Any) -> Any:
        """Implementation traced when the custom op is enabled."""

    @abstractmethod
    def forward_native(self, *args: Any, **kwargs: Any) -> Any:
        """Implementation traced when the custom op is disabled."""

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        return self.forward(*args, **kwargs)

    def empty(self, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Empty tensor in the model dtype on the configured device."""
        return torch.empty(*args, dtype=self.model_dtype, device=self.device, **kwargs)

    def empty_int64(self, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Empty int64 tensor on the configured device."""
        return torch.empty(*args, dtype=torch.int64, device=self.device, **kwargs)

    def empty_f32(self, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Empty float32 tensor on the configured device."""
        return torch.empty(*args, dtype=torch.float32, device=self.device, **kwargs)

    def inputs(self) -> list[torch.Tensor]:
        """Utility for inputs to the pattern"""
        raise NotImplementedError
|
||||
|
||||
|
||||
class MatcherRotaryEmbedding(MatcherCustomOp):
    """Matcher stand-in for the rotary-embedding op.

    Depending on configuration it traces the flashinfer rotary op, the ROCm
    aiter triton op, or the default ``_C.rotary_embedding`` kernel.
    """

    def __init__(
        self,
        is_neox: bool,
        head_size: int,
        num_heads: int,
        num_kv_heads: int,
        use_flashinfer: bool = False,
        match_rocm_aiter: bool | None = None,
        enabled: bool | None = None,
    ) -> None:
        # Fall back to the global custom-op enablement flags when the caller
        # does not force a choice.
        if enabled is None:
            enabled = RotaryEmbedding.enabled()
        if match_rocm_aiter is None:
            match_rocm_aiter = rocm_aiter_ops.is_triton_rotary_embed_enabled()

        super().__init__(enabled)
        self.is_neox = is_neox
        self.head_size = head_size
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.q_size = self.num_heads * self.head_size
        self.kv_size = self.num_kv_heads * self.head_size
        # NOTE(review): rotary_dim is assumed equal to head_size here;
        # partial-rotary models would need this to be a separate parameter.
        self.rotary_dim = head_size
        # Pick the rotary op this pattern should trace; flashinfer takes
        # precedence over the ROCm aiter path.
        if use_flashinfer:
            self.rotary_op = FLASHINFER_ROTARY_OP
        elif match_rocm_aiter:
            self.rotary_op = rocm_aiter_ops.get_triton_rotary_embedding_op()
        else:
            self.rotary_op = ROTARY_OP

    def inputs(self) -> list[torch.Tensor]:
        """Sample tensors for tracing the pattern (5 tokens, 4096-entry cache)."""
        positions = self.empty_int64(5)
        query = self.empty(5, self.q_size)
        key = self.empty(5, self.kv_size)
        cos_sin_cache = self.empty(4096, self.rotary_dim)
        return [positions, query, key, cos_sin_cache]

    def forward_custom(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor | None,
        cos_sin_cache: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        # The custom op mutates query/key in place; auto_functionalized
        # wraps it so outputs appear at indices 1+ of the result tuple.
        result = auto_functionalized(
            self.rotary_op,
            positions=positions,
            query=query,
            key=key,
            head_size=self.head_size,
            cos_sin_cache=cos_sin_cache,
            is_neox=self.is_neox,
        )
        query_out = result[1]
        # Some rotary ops may not produce a key output; guard the index.
        key_out = result[2] if len(result) > 2 else None
        return query_out, key_out

    def forward_native(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor | None,
        cos_sin_cache: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        # Pure-PyTorch reference path used when the custom op is disabled.
        result: tuple[torch.Tensor, torch.Tensor | None] = (
            RotaryEmbedding.forward_static(
                positions,
                query,
                key,
                self.head_size,
                self.rotary_dim,
                cos_sin_cache,
                self.is_neox,
            )
        )
        return result
|
||||
|
||||
|
||||
class MatcherRMSNorm(MatcherCustomOp):
    """Matcher stand-in for (non-residual) RMSNorm."""

    def __init__(
        self,
        epsilon: float,
        enabled: bool | None = None,
        match_rocm_aiter: bool = False,
    ) -> None:
        # Default to the global custom-op enablement for RMSNorm.
        if enabled is None:
            enabled = RMSNorm.enabled()

        super().__init__(enabled)
        self.epsilon = epsilon
        self._rmsnorm_op = RMS_OP
        self.match_rocm_aiter = match_rocm_aiter

        # ROCm aiter provides its own rmsnorm kernel with a direct call
        # convention (no auto_functionalized wrapping needed).
        if match_rocm_aiter:
            self._rmsnorm_op = rocm_aiter_ops.get_rmsnorm_op()

    def inputs(self) -> list[torch.Tensor]:
        """Sample (input, weight) tensors for tracing.

        When the custom op is disabled the native path is traced in fp32,
        hence the dtype switch on ``self.enabled``.
        """
        input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16)
        weight = self.empty(16)
        return [input, weight]

    def forward_rocm_aiter(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
    ) -> torch.Tensor:
        # Direct functional call into the aiter kernel.
        return self._rmsnorm_op(
            x=input,
            weight=weight,
            variance_epsilon=self.epsilon,
        )

    def forward_custom(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
    ) -> torch.Tensor:
        if self.match_rocm_aiter:
            return self.forward_rocm_aiter(input, weight)

        # _C.rms_norm writes into a preallocated result buffer; wrap with
        # auto_functionalized so the traced graph stays functional.
        result = torch.empty_like(input)
        _, result = auto_functionalized(
            self._rmsnorm_op,
            result=result,
            input=input,
            weight=weight,
            epsilon=self.epsilon,
        )

        return result

    def forward_native(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
    ) -> torch.Tensor:
        # Pure-PyTorch reference implementation.
        return RMSNorm.forward_static(
            input, self.epsilon, input.size(-1), self.model_dtype, weight
        )
|
||||
|
||||
|
||||
class MatcherFusedAddRMSNorm(MatcherCustomOp):
    """Matcher stand-in for RMSNorm fused with the residual add.

    Returns ``(normalized, updated_residual)``.
    """

    def __init__(
        self,
        epsilon: float,
        enabled: bool | None = None,
        match_rocm_aiter: bool = False,
    ) -> None:
        # Default to the global custom-op enablement for RMSNorm.
        if enabled is None:
            enabled = RMSNorm.enabled()

        super().__init__(enabled)
        self.epsilon = epsilon
        self.match_rocm_aiter = match_rocm_aiter

        self._rmsnorm_op = RMS_ADD_OP

        # ROCm aiter provides its own fused-add rmsnorm kernel.
        if match_rocm_aiter:
            self._rmsnorm_op = rocm_aiter_ops.get_rmsnorm_fused_add_op()

    def inputs(self) -> list[torch.Tensor]:
        """Sample (input, weight, residual) tensors for tracing.

        The native path is traced in fp32 when the custom op is disabled.
        """
        input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16)
        weight = self.empty(16)
        residual = self.empty(5, 16)
        return [input, weight, residual]

    def forward_rocm_aiter(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
        residual: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Direct functional call into the aiter kernel.
        return self._rmsnorm_op(  # type: ignore[no-any-return]
            x=input, residual=residual, weight=weight, variance_epsilon=self.epsilon
        )

    def forward_custom(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
        residual: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        if self.match_rocm_aiter:
            return self.forward_rocm_aiter(input, weight, residual)

        # fused_add_rms_norm mutates input and residual in place; the
        # functionalized form yields (token, result, residual).
        _, result, residual = auto_functionalized(
            self._rmsnorm_op,
            input=input,
            residual=residual,
            weight=weight,
            epsilon=self.epsilon,
        )

        return result, residual

    def forward_native(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
        residual: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Pure-PyTorch reference path (residual variant of forward_static).
        result: tuple[torch.Tensor, torch.Tensor] = RMSNorm.forward_static(
            input, self.epsilon, input.size(-1), self.model_dtype, weight, residual
        )
        return result
|
||||
|
||||
|
||||
class MatcherQuantFP8(MatcherCustomOp):
    """Matcher stand-in for FP8 quantization.

    Supports static per-tensor, dynamic per-tensor/per-token, and dynamic
    per-group quantization, with optional ROCm-aiter kernel variants.
    Returns ``(quantized, scale)`` from every forward path.

    Fix vs. original: the dtype assertion message was truncated
    ("Only QuantFP8 supported by") and has been completed so failures are
    actionable.
    """

    def __init__(
        self,
        quant_key: QuantKey,
        enabled: bool | None = None,
        has_col_major_scales: bool = False,
        is_e8m0: bool = False,
        match_rocm_aiter: bool = False,
        is_tma_aligned: bool = False,
    ) -> None:
        # Default to the global custom-op enablement for QuantFP8.
        if enabled is None:
            enabled = QuantFP8.enabled()

        super().__init__(enabled)
        self.quant_key = quant_key
        self.has_col_major_scales = has_col_major_scales
        self.is_e8m0 = is_e8m0
        self.match_rocm_aiter = match_rocm_aiter
        self.is_tma_aligned = is_tma_aligned

        if match_rocm_aiter:
            # aiter only provides per-token and group-128 kernels.
            assert not quant_key.scale.group_shape.is_per_tensor(), (
                "ROCm aiter fusion pass does not support per tensor quantization"
            )
            if quant_key.scale.group_shape.is_per_token():
                self.QUANT_OP = rocm_aiter_ops.get_per_token_quant_op()
            else:
                assert quant_key.scale.group_shape.col == 128, (
                    "ROCm aiter fusion pass currently supports "
                    "quantization operation with group_size 128"
                )
                # fnuz platforms use the aiter group-quant kernel; otherwise
                # fall back to the triton per-token-group kernel.
                if current_platform.is_fp8_fnuz():
                    self.QUANT_OP = rocm_aiter_ops.get_group_quant_op()
                else:
                    self.QUANT_OP = (
                        torch.ops.vllm.triton_per_token_group_quant_fp8.default
                    )

        else:
            assert quant_key in QUANT_OPS, (
                f"unsupported quantization scheme {quant_key}"
            )
            self.QUANT_OP = QUANT_OPS[quant_key]

        # This matcher only handles the platform's FP8 dtype, with a single
        # scale level (no secondary scale).
        assert quant_key.dtype == current_platform.fp8_dtype(), (
            "MatcherQuantFP8 only supports the current platform's FP8 dtype"
        )
        assert quant_key.scale2 is None

        # Native (non-custom-op) reference implementation for this scheme.
        self.quant_fp8 = QuantFP8(
            quant_key.scale.static,
            quant_key.scale.group_shape,
            column_major_scales=has_col_major_scales,
            use_ue8m0=is_e8m0,
            tma_aligned_scales=self.is_tma_aligned,
            compile_native=False,
        )

    def forward_rocm_aiter(
        self,
        input: torch.Tensor,
        scale: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Dispatch to the aiter kernel selected in ``__init__``."""
        quant_key_group_shape = self.quant_key.scale.group_shape
        if quant_key_group_shape == GroupShape.PER_TOKEN:
            return self.QUANT_OP(  # type: ignore[no-any-return]
                x=input,
                quant_dtype=self.quant_key.dtype,
                scale=scale,
            )
        else:
            # Group kernels take the group size positionally.
            return self.QUANT_OP(input, quant_key_group_shape.col)  # type: ignore[no-any-return]

    def forward_custom(
        self,
        input: torch.Tensor,
        scale: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Trace the custom quant kernel for the configured scheme."""
        if self.match_rocm_aiter:
            return self.forward_rocm_aiter(input, scale)

        result = torch.empty(
            input.shape, device=input.device, dtype=self.quant_key.dtype
        )

        if self.quant_key.scale.group_shape.is_per_group():
            # for tma_aligned, the scale must be passed to forward_custom
            # tma_aligned fusion then matches by custom op arguments
            if not self.is_tma_aligned:
                assert scale is None
                scale = self.make_scale(input, transposed=self.has_col_major_scales)

            finfo = torch.finfo(self.quant_key.dtype)
            fp8_min = finfo.min
            fp8_max = finfo.max

            _, result, scale = auto_functionalized(
                self.QUANT_OP,
                input=input,
                output_q=result,
                output_s=scale,
                group_size=self.quant_key.scale.group_shape[1],
                eps=1e-10,
                fp8_min=fp8_min,
                fp8_max=fp8_max,
                scale_ue8m0=self.is_e8m0,
                dummy_is_scale_transposed=self.has_col_major_scales,
                dummy_is_tma_aligned=self.is_tma_aligned,
            )
            return result, scale

        if self.quant_key.scale.static:
            # Static scheme: the caller supplies the (per-tensor) scale.
            assert scale is not None
            _, result = auto_functionalized(
                self.QUANT_OP, result=result, input=input, scale=scale
            )
            return result, scale
        else:
            # Dynamic scheme: the kernel computes the scale into a buffer
            # we allocate here.
            assert scale is None
            scale = self.make_scale(input)
            _, result, scale = auto_functionalized(
                self.QUANT_OP, result=result, input=input, scale=scale, scale_ub=None
            )
            return result, scale

    def forward_native(
        self,
        input: torch.Tensor,
        scale: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Reference implementation via the QuantFP8 layer."""
        return self.quant_fp8(input, scale)  # type: ignore[no-any-return]

    def make_scale(self, input: torch.Tensor, transposed: bool = False) -> torch.Tensor:
        """Allocate the scale buffer for dynamic quantization of ``input``.

        With ``transposed=True`` the buffer is allocated column-major
        (reversed shape, then permuted back) to match kernels that expect
        column-major scales.
        """
        normalized_group_shape = _normalize_quant_group_shape(
            input, self.quant_key.scale.group_shape
        )
        scale_shape = (
            input.shape[0] // normalized_group_shape[0],
            input.shape[1] // normalized_group_shape[1],
        )
        if transposed:
            scale_shape = tuple(reversed(scale_shape))
            return torch.empty(
                scale_shape, device=input.device, dtype=torch.float32
            ).permute(-1, -2)

        return torch.empty(scale_shape, device=input.device, dtype=torch.float32)

    def inputs(self) -> list[torch.Tensor]:
        """Sample inputs for tracing; static schemes also need a scale."""
        input = self.empty(5, 16)
        if self.quant_key.scale.static:
            return [input, self.empty_f32(1, 1)]

        return [input]
|
||||
|
||||
|
||||
class MatcherSiluAndMul(MatcherCustomOp):
    """Matcher stand-in for the SiLU-and-multiply activation."""

    def __init__(self, enabled: bool | None = None) -> None:
        # Default to the global custom-op enablement for SiluAndMul.
        if enabled is None:
            enabled = SiluAndMul.enabled()
        super().__init__(enabled)

    def inputs(self) -> list[torch.Tensor]:
        """Sample input for tracing: 5 tokens, gate+up width of 4."""
        return [self.empty(5, 4)]

    def forward_custom(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Trace the fused kernel; output halves the last dimension."""
        half = x.shape[-1] // 2
        out = torch.empty(x.shape[:-1] + (half,), dtype=x.dtype, device=x.device)
        return auto_functionalized(SILU_MUL_OP, result=out, input=x)[1]

    def forward_native(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Pure-PyTorch reference implementation."""
        return SiluAndMul.forward_native(x)
|
||||
244
vllm/compilation/passes/fusion/qk_norm_rope_fusion.py
Normal file
244
vllm/compilation/passes/fusion/qk_norm_rope_fusion.py
Normal file
@@ -0,0 +1,244 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import ParamSpec
|
||||
|
||||
import torch
|
||||
import torch._inductor.pattern_matcher as pm
|
||||
from torch import fx
|
||||
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||
|
||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention import Attention
|
||||
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
|
||||
|
||||
from ..inductor_pass import enable_fake_mode
|
||||
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
from .matcher_utils import MatcherRMSNorm, MatcherRotaryEmbedding
|
||||
from .rms_quant_fusion import empty_bf16, empty_fp32, empty_i64
|
||||
|
||||
logger = init_logger(__name__)

# Fused kernel that applies Q/K RMSNorm and RoPE in one pass over the qkv
# projection output.
FUSED_QK_ROPE_OP = torch.ops._C.fused_qk_norm_rope.default

# ParamSpec used to type the trace-function wrapper in QkNormRopePattern.
P = ParamSpec("P")
|
||||
|
||||
|
||||
class QkNormRopePattern:
    """
    Match the unfused sequence in attention blocks and replace with the fused op.

    Unfused (conceptually):
        q, k, v = split(qkv, [qsz, kvsz, kvsz], -1)
        qh = reshape(q, [-1, num_heads, head_dim])
        kh = reshape(k, [-1, num_kv_heads, head_dim])
        qn = rms_norm(qh, q_weight, eps)
        kn = rms_norm(kh, k_weight, eps)
        qf = reshape(qn, [-1, num_heads * head_dim])
        kf = reshape(kn, [-1, num_kv_heads * head_dim])
        qf, kf = rotary_embedding(positions, qf, kf, head_dim, cos_sin_cache, is_neox)
        return qf, kf, v

    Fused replacement:
        fused_qk_norm_rope(qkv, num_heads, num_kv_heads, num_kv_heads, head_dim,
                           eps, q_weight, k_weight, cos_sin_cache, is_neox,
                           positions.view(-1))
        return split(qkv, [qsz, kvsz, kvsz], -1)
    """

    def __init__(
        self,
        head_dim: int,
        num_heads: int,
        num_kv_heads: int,
        eps: float,
        is_neox: bool,
        rope_flashinfer: bool = False,
    ) -> None:
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        # Flattened widths of the q and k/v slices within the qkv tensor.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.eps = eps
        # Matchers trace the same norm/rope variants the model would run.
        self.rmsnorm_matcher = MatcherRMSNorm(eps)
        self.is_neox = is_neox
        self.rope_flashinfer = rope_flashinfer
        self.rope_matcher = MatcherRotaryEmbedding(
            is_neox=is_neox,
            head_size=self.head_dim,
            num_heads=self.num_heads,
            num_kv_heads=self.num_kv_heads,
            use_flashinfer=self.rope_flashinfer,
        )

    def get_inputs(self) -> list[torch.Tensor]:
        # Sample inputs to help pattern tracing
        T = 5
        qkv = empty_bf16(T, self.q_size + 2 * self.kv_size)
        positions = empty_i64(T)
        q_weight = empty_bf16(1, self.head_dim)
        k_weight = empty_bf16(1, self.head_dim)
        # The flashinfer rotary op uses an fp32 cos/sin cache; the default
        # path traces it in bf16.
        if self.rope_flashinfer:
            cos_sin_cache = empty_fp32(4096, self.head_dim)
        else:
            cos_sin_cache = empty_bf16(4096, self.head_dim)
        return [
            qkv,
            positions,
            q_weight,
            k_weight,
            cos_sin_cache,
        ]

    @staticmethod
    def wrap_trace_fn(
        trace_fn: Callable[P, fx.GraphModule],
        *process_fx_fns: Callable[[fx.GraphModule], None],
    ) -> Callable[P, fx.GraphModule]:
        """Wrap a pattern trace function with in-place fx post-processing steps."""

        def wrapped(*args: P.args, **kwargs: P.kwargs) -> fx.GraphModule:
            gm = trace_fn(*args, **kwargs)
            for process_fx in process_fx_fns:
                process_fx(gm)

            return gm

        return wrapped

    @staticmethod
    def fx_view_to_reshape(gm: torch.fx.GraphModule) -> None:
        # Canonicalize view ops to reshape in the traced pattern graph.
        from torch._inductor.fx_passes.post_grad import view_to_reshape

        view_to_reshape(gm)

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the unfused->fused replacement with the pattern pass."""

        def pattern(
            qkv: torch.Tensor,
            positions: torch.Tensor,
            q_weight: torch.Tensor,
            k_weight: torch.Tensor,
            cos_sin_cache: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            # split qkv -> q,k,v
            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

            # Q path: view -> RMS -> view back to q.shape
            q_by_head = q.view(
                *q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim
            )
            q_normed_by_head = self.rmsnorm_matcher(q_by_head, q_weight)
            q_flat = q_normed_by_head.view(q.shape)

            # K path: view -> RMS -> view back to k.shape
            k_by_head = k.view(
                *k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim
            )
            k_normed_by_head = self.rmsnorm_matcher(k_by_head, k_weight)
            k_flat = k_normed_by_head.view(k.shape)

            # RoPE: apply to flattened q/k
            q_rope, k_rope = self.rope_matcher(positions, q_flat, k_flat, cos_sin_cache)
            return q_rope, k_rope, v

        def replacement(
            qkv: torch.Tensor,
            positions: torch.Tensor,
            q_weight: torch.Tensor,
            k_weight: torch.Tensor,
            cos_sin_cache: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            # Run fused qk_norm_rope op
            result = auto_functionalized(
                FUSED_QK_ROPE_OP,
                qkv=qkv,
                num_heads_q=self.num_heads,
                num_heads_k=self.num_kv_heads,
                num_heads_v=self.num_kv_heads,
                head_dim=self.head_dim,
                eps=self.eps,
                q_weight=q_weight,
                k_weight=k_weight,
                cos_sin_cache=cos_sin_cache,
                is_neox=self.is_neox,
                position_ids=positions.view(-1),
            )
            result_qkv = result[1]

            # Split back to q,k,v and return
            return result_qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)  # type: ignore[no-any-return]

        # NOTE: use fx_view_to_reshape to unify view/reshape to simplify
        # pattern and increase matching opportunities
        pm.register_replacement(
            pattern,
            replacement,
            self.get_inputs(),
            QkNormRopePattern.wrap_trace_fn(
                pm.fwd_only,
                QkNormRopePattern.fx_view_to_reshape,
            ),
            pm_pass,
        )
|
||||
|
||||
|
||||
class QKNormRoPEFusionPass(VllmPatternMatcherPass):
    """Fuse Q/K RMSNorm + RoPE into fused_qk_norm_rope when the custom op exists."""

    @enable_fake_mode
    def __init__(self, config: VllmConfig) -> None:
        super().__init__(config)
        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="qk_norm_rope_fusion_pass"
        )

        # The fused path is only registered for half-precision models.
        dtype = config.model_config.dtype
        if dtype not in (torch.bfloat16, torch.float16):
            logger.warning_once(
                "QK Norm+RoPE fusion not enabled: unsupported dtype %s", dtype
            )
            return

        # use one attn layer to get meta (such as head_dim) for QkNormRopePattern
        attn_layers: dict[str, Attention] = get_layers_from_vllm_config(
            config, Attention
        )
        if len(attn_layers) == 0:
            logger.warning_once(
                "QK Norm+RoPE fusion enabled, but no Attention layers were discovered."
            )
            return
        layer = next(iter(attn_layers.values()))

        # Register one pattern per (epsilon, neox-style) combination; when the
        # custom rotary op is enabled, also cover both flashinfer variants.
        for epsilon in [1e-5, 1e-6]:
            for neox in [True, False]:
                if RotaryEmbedding.enabled():
                    for rope_flashinfer in [False, True]:
                        QkNormRopePattern(
                            head_dim=layer.head_size,
                            num_heads=layer.num_heads,
                            num_kv_heads=layer.num_kv_heads,
                            eps=epsilon,
                            is_neox=neox,
                            rope_flashinfer=rope_flashinfer,
                        ).register(self.patterns)
                else:
                    QkNormRopePattern(
                        head_dim=layer.head_size,
                        num_heads=layer.num_heads,
                        num_kv_heads=layer.num_kv_heads,
                        eps=epsilon,
                        is_neox=neox,
                    ).register(self.patterns)

        self.dump_patterns(config, self.patterns)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph) -> None:
        # Apply all registered replacements; keep the count for logging.
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Fused QK Norm+RoPE on %s sites", self.matched_count)

    def uuid(self) -> str:
        # Cache key derived from this pass plus the pattern definition source.
        return VllmInductorPass.hash_source(self, QkNormRopePattern)
|
||||
643
vllm/compilation/passes/fusion/rms_quant_fusion.py
Normal file
643
vllm/compilation/passes/fusion/rms_quant_fusion.py
Normal file
@@ -0,0 +1,643 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Any, NamedTuple
|
||||
|
||||
import torch
|
||||
import torch._inductor.pattern_matcher as pm
|
||||
from torch import fx
|
||||
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||
from torch._ops import OpOverload
|
||||
|
||||
from vllm.config import VllmConfig, get_current_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
QuantKey,
|
||||
ScaleDesc,
|
||||
kFp8Dynamic64Sym,
|
||||
kFp8Dynamic128Sym,
|
||||
kFp8DynamicTensorSym,
|
||||
kFp8DynamicTokenSym,
|
||||
kFp8StaticTensorSym,
|
||||
kNvfp4Dynamic,
|
||||
kStaticTensorScale,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ..inductor_pass import enable_fake_mode
|
||||
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
from .matcher_utils import (
|
||||
MatcherFusedAddRMSNorm,
|
||||
MatcherQuantFP8,
|
||||
MatcherRMSNorm,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
# Platform-dependent FP8 dtype (e.g. fnuz variant on some ROCm hardware).
FP8_DTYPE = current_platform.fp8_dtype()
# FP4 values are carried in uint8 storage (torch has no native fp4 dtype).
FP4_DTYPE = torch.uint8
|
||||
|
||||
|
||||
def empty_bf16(*sizes: Any, **extra: Any) -> torch.Tensor:
    """Allocate an uninitialized bfloat16 CUDA tensor of the given shape."""
    return torch.empty(*sizes, dtype=torch.bfloat16, device="cuda", **extra)


def empty_fp32(*sizes: Any, **extra: Any) -> torch.Tensor:
    """Allocate an uninitialized float32 CUDA tensor of the given shape."""
    return torch.empty(*sizes, dtype=torch.float32, device="cuda", **extra)


def empty_i32(*sizes: Any, **extra: Any) -> torch.Tensor:
    """Allocate an uninitialized int32 CUDA tensor of the given shape."""
    return torch.empty(*sizes, dtype=torch.int32, device="cuda", **extra)


def empty_i64(*sizes: Any, **extra: Any) -> torch.Tensor:
    """Allocate an uninitialized int64 CUDA tensor of the given shape."""
    return torch.empty(*sizes, dtype=torch.int64, device="cuda", **extra)
|
||||
|
||||
|
||||
# Unfused custom ops that the patterns in this file match against.
RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default

# Map from quantization scheme to the custom quant kernel implementing it.
QUANT_OPS: dict[QuantKey, OpOverload] = {
    kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default,  # noqa: E501
    kFp8DynamicTensorSym: torch.ops._C.dynamic_scaled_fp8_quant.default,  # noqa: E501
    kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default,  # noqa: E501
}
# NVFP4 quant kernel is only registered on CUDA builds that ship it.
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
    QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.default
# Group sizes 64 and 128 share one per-token-group kernel on CUDA.
if current_platform.is_cuda():
    QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default  # noqa: E501
    QUANT_OPS[kFp8Dynamic64Sym] = torch.ops._C.per_token_group_fp8_quant.default  # noqa: E501
|
||||
|
||||
|
||||
class FusedRMSQuantKey(NamedTuple):
    """Key identifying one RMSNorm + quant fusion variant.

    Fields:
        quant: the quantization scheme being fused.
        fused_add: whether the op also performs the residual add.
    """

    quant: QuantKey
    fused_add: bool

    def __str__(self) -> str:
        residual_tag = "" if self.fused_add else "out"
        return f"FusedQuantKey({self.quant}, with{residual_tag} residual)"
|
||||
|
||||
|
||||
# Map each (quant scheme, residual-add) key to the fused custom kernel
# implementing it. Several keys share one kernel (e.g. per-token and the
# per-block variants), which accepts the distinguishing options at runtime.
FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = {
    FusedRMSQuantKey(
        kFp8StaticTensorSym, False
    ): torch.ops._C.rms_norm_static_fp8_quant.default,  # noqa: E501
    FusedRMSQuantKey(
        kFp8StaticTensorSym, True
    ): torch.ops._C.fused_add_rms_norm_static_fp8_quant.default,  # noqa: E501
    FusedRMSQuantKey(
        kFp8DynamicTokenSym, False
    ): torch.ops._C.rms_norm_dynamic_per_token_quant.default,  # noqa: E501
    FusedRMSQuantKey(
        kFp8DynamicTokenSym, True
    ): torch.ops._C.rms_norm_dynamic_per_token_quant.default,  # noqa: E501
    FusedRMSQuantKey(
        kFp8Dynamic128Sym, False
    ): torch.ops._C.rms_norm_per_block_quant.default,  # noqa: E501
    FusedRMSQuantKey(
        kFp8Dynamic128Sym, True
    ): torch.ops._C.rms_norm_per_block_quant.default,  # noqa: E501
    FusedRMSQuantKey(
        kFp8Dynamic64Sym, False
    ): torch.ops._C.rms_norm_per_block_quant.default,  # noqa: E501
    FusedRMSQuantKey(
        kFp8Dynamic64Sym, True
    ): torch.ops._C.rms_norm_per_block_quant.default,  # noqa: E501
}
|
||||
|
||||
|
||||
class RMSNormQuantPattern:
    """Shared state for the RMSNorm + quant fusion patterns.

    Resolves the fused custom op for ``key`` and builds the matcher helpers
    that subclasses use to trace the unfused pattern.
    """

    def __init__(
        self,
        epsilon: float,
        key: FusedRMSQuantKey,
        has_col_major_scales: bool = False,
        is_e8m0: bool = False,
        is_tma_aligned: bool = False,
    ) -> None:
        self.epsilon = epsilon
        self.quant_dtype = key.quant.dtype
        vllm_config = get_current_vllm_config()
        if vllm_config.model_config:
            self.model_dtype = vllm_config.model_config.dtype
        else:
            self.model_dtype = None

        # A fused kernel must exist for this exact (quant, fused_add) key.
        assert key in FUSED_OPS, f"unsupported fused rmsnorm+quant op for {key}"
        self.FUSED_OP = FUSED_OPS[key]

        # Matcher for the norm half of the pattern, residual-aware or not.
        if key.fused_add:
            self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
        else:
            self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
        # Matcher for the quant half of the pattern.
        self.quant_matcher = MatcherQuantFP8(
            key.quant,
            has_col_major_scales=has_col_major_scales,
            is_e8m0=is_e8m0,
            is_tma_aligned=is_tma_aligned,
        )
|
||||
|
||||
|
||||
class RMSNormStaticQuantPattern(RMSNormQuantPattern):
    """Fuse (non-residual) RMSNorm + static per-tensor FP8 quant."""

    def __init__(
        self, epsilon: float, quant_dtype: torch.dtype, symmetric: bool = True
    ) -> None:
        fused_key = FusedRMSQuantKey(
            fused_add=False,
            quant=QuantKey(
                dtype=quant_dtype, scale=kStaticTensorScale, symmetric=symmetric
            ),
        )
        super().__init__(epsilon, fused_key)

    def register(self, pm_pass: PatternMatcherPass) -> None:
        # Cannot use methods, as the self argument affects tracing
        def pattern(
            input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
        ) -> torch.Tensor:
            result_rms = self.rmsnorm_matcher(input, weight)
            # Only the quantized output (index 0) participates in the match.
            return self.quant_matcher(result_rms, scale)[0]

        def replacement(
            input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
        ) -> torch.Tensor:
            # In case we're matching native rms-norm, conversions might be
            # optimized out. We convert here just to be safe.
            input = input.to(dtype=self.model_dtype)

            result = torch.empty(
                input.shape, device=input.device, dtype=self.quant_dtype
            )
            at = auto_functionalized(
                self.FUSED_OP,
                result=result,
                input=input,
                weight=weight,
                scale=scale,
                epsilon=self.epsilon,
            )

            # result
            return at[1]

        inputs = [
            # input, weight
            *self.rmsnorm_matcher.inputs(),
            self.quant_matcher.inputs()[1],  # scale
        ]
        # NOTE(review): pattern is invoked once here before registration,
        # presumably to materialize matcher state prior to tracing — the
        # fused-add variant below does not do this; confirm it is intentional.
        pattern(*inputs)

        pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
    """Fuse residual-add RMSNorm + static per-tensor FP8 quant."""

    def __init__(
        self, epsilon: float, quant_dtype: torch.dtype, symmetric: bool = True
    ) -> None:
        key = FusedRMSQuantKey(
            fused_add=True,
            quant=QuantKey(
                dtype=quant_dtype, scale=kStaticTensorScale, symmetric=symmetric
            ),
        )
        super().__init__(epsilon, key)

    def register(self, pm_pass: PatternMatcherPass) -> None:
        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # Residual-aware norm produces (normalized, updated residual).
            result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
            result, _ = self.quant_matcher(result_rms, scale)

            return result, residual

        def replacement(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # In case we're matching native rms-norm, conversions might be
            # optimized out. We convert here just to be safe.
            input = input.to(dtype=self.model_dtype)

            result = torch.empty_like(input, dtype=self.quant_dtype)
            at = auto_functionalized(
                self.FUSED_OP,
                result=result,
                input=input,
                residual=residual,
                weight=weight,
                scale=scale,
                epsilon=self.epsilon,
            )

            # result, residual
            return at[1], at[2]

        inputs = [
            # input, weight, residual
            *self.rmsnorm_matcher.inputs(),
            self.quant_matcher.inputs()[1],  # scale
        ]

        pm.register_replacement(
            pattern,
            replacement,
            inputs,
            pm.fwd_only,
            pm_pass,
        )
|
||||
|
||||
|
||||
class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern):
|
||||
    def __init__(
        self,
        epsilon: float,
        quant_dtype: torch.dtype,
        group_shape: GroupShape,
        symmetric: bool = True,
        is_e8m0: bool = False,
        has_col_major_scales: bool = True,
        is_tma_aligned: bool = True,
    ) -> None:
        """Set up the residual-add RMSNorm + per-group quant fusion state.

        Group quantization uses a dynamic (non-static) fp32 scale per group,
        optionally column-major and/or TMA-aligned as required by the
        downstream GEMM kernels.
        """
        scale = ScaleDesc(torch.float32, False, group_shape)
        key = FusedRMSQuantKey(
            fused_add=True,
            quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
        )
        self.group_shape = group_shape
        self.is_e8m0 = is_e8m0
        self.has_col_major_scales = has_col_major_scales
        self.is_tma_aligned = is_tma_aligned
        super().__init__(
            epsilon,
            key,
            has_col_major_scales=has_col_major_scales,
            is_e8m0=is_e8m0,
            is_tma_aligned=is_tma_aligned,
        )
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
|
||||
result = torch.empty(
|
||||
result_rms.shape,
|
||||
device=result_rms.device,
|
||||
dtype=self.quant_matcher.quant_key.dtype,
|
||||
)
|
||||
assert scale is not None
|
||||
finfo = torch.finfo(self.quant_matcher.quant_key.dtype)
|
||||
fp8_min = finfo.min
|
||||
fp8_max = finfo.max
|
||||
|
||||
_, result, scale = auto_functionalized(
|
||||
self.quant_matcher.QUANT_OP,
|
||||
input=result_rms,
|
||||
output_q=result,
|
||||
output_s=scale,
|
||||
group_size=self.quant_matcher.quant_key.scale.group_shape[1],
|
||||
eps=1e-10,
|
||||
fp8_min=fp8_min,
|
||||
fp8_max=fp8_max,
|
||||
scale_ue8m0=self.quant_matcher.is_e8m0,
|
||||
dummy_is_scale_transposed=self.has_col_major_scales,
|
||||
dummy_is_tma_aligned=self.is_tma_aligned,
|
||||
)
|
||||
|
||||
return result, residual, scale
|
||||
|
||||
def replacement(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
# In case we're matching native rms-norm, conversions might be
|
||||
# optimized out. We convert here just to be safe.
|
||||
input = input.to(dtype=self.model_dtype)
|
||||
|
||||
result = torch.empty_like(input, dtype=self.quant_dtype)
|
||||
|
||||
at = auto_functionalized(
|
||||
self.FUSED_OP,
|
||||
result=result,
|
||||
input=input,
|
||||
weight=weight,
|
||||
scale=scale,
|
||||
epsilon=self.epsilon,
|
||||
scale_ub=None,
|
||||
residual=residual,
|
||||
group_size=self.group_shape[1],
|
||||
is_scale_transposed=self.has_col_major_scales,
|
||||
)
|
||||
|
||||
# result, residual, scale
|
||||
return at[1], at[3], at[2]
|
||||
|
||||
scale = self.quant_matcher.empty_f32(1, 1)
|
||||
|
||||
pm.register_replacement(
|
||||
pattern,
|
||||
replacement,
|
||||
self.rmsnorm_matcher.inputs() + [scale],
|
||||
pm.fwd_only,
|
||||
pm_pass,
|
||||
)
|
||||
|
||||
|
||||
class RMSNormGroupQuantPattern(RMSNormQuantPattern):
    """Fuses plain rms_norm + group-wise FP8 quant into ``self.FUSED_OP``.

    Same structure as FusedAddRMSNormGroupQuantPattern but without the
    residual add (``fused_add=False`` and ``residual=None`` in the fused op).
    """

    def __init__(
        self,
        epsilon: float,
        quant_dtype: torch.dtype,
        group_shape: GroupShape,
        symmetric: bool = True,
        is_e8m0: bool = False,
        has_col_major_scales: bool = True,
        is_tma_aligned: bool = True,
    ) -> None:
        # Group-wise (non-static) scale descriptor for the fused-op key.
        scale = ScaleDesc(torch.float32, False, group_shape)
        key = FusedRMSQuantKey(
            fused_add=False,
            quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
        )
        self.group_shape = group_shape
        self.has_col_major_scales = has_col_major_scales
        self.is_tma_aligned = is_tma_aligned
        super().__init__(
            epsilon,
            key,
            has_col_major_scales=self.has_col_major_scales,
            is_e8m0=is_e8m0,
            is_tma_aligned=is_tma_aligned,
        )

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern/replacement pair with the pattern matcher."""

        def pattern(
            input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor]:
            result_rms = self.rmsnorm_matcher(input, weight)
            result = torch.empty(
                result_rms.shape,
                device=result_rms.device,
                dtype=self.quant_matcher.quant_key.dtype,
            )
            assert scale is not None
            finfo = torch.finfo(self.quant_matcher.quant_key.dtype)
            fp8_min = finfo.min
            fp8_max = finfo.max

            # Unfused group quant op; "dummy_*" kwargs only disambiguate
            # layout variants for pattern matching.
            _, result, scale = auto_functionalized(
                self.quant_matcher.QUANT_OP,
                input=result_rms,
                output_q=result,
                output_s=scale,
                group_size=self.quant_matcher.quant_key.scale.group_shape[1],
                eps=1e-10,
                fp8_min=fp8_min,
                fp8_max=fp8_max,
                scale_ue8m0=self.quant_matcher.is_e8m0,
                dummy_is_scale_transposed=self.has_col_major_scales,
                dummy_is_tma_aligned=self.is_tma_aligned,
            )

            return result, scale

        def replacement(
            input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # In case we're matching native rms-norm, conversions might be
            # optimized out. We convert here just to be safe.
            input = input.to(dtype=self.model_dtype)

            result = torch.empty_like(input, dtype=self.quant_dtype)
            at = auto_functionalized(
                self.FUSED_OP,
                result=result,
                input=input,
                weight=weight,
                scale=scale,
                epsilon=self.epsilon,
                scale_ub=None,
                residual=None,
                group_size=self.group_shape[1],
                is_scale_transposed=self.has_col_major_scales,
            )

            # result, scale
            return at[1], at[2]

        # Placeholder output-scale tensor used only for tracing the pattern.
        scale = self.quant_matcher.empty_f32(1, 1)

        pm.register_replacement(
            pattern,
            replacement,
            self.rmsnorm_matcher.inputs() + [scale],
            pm.fwd_only,
            pm_pass,
        )
|
||||
|
||||
|
||||
class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
    """Fuses rms_norm + dynamic (per-token by default) FP8 quant into
    ``self.FUSED_OP``."""

    def __init__(
        self,
        epsilon: float,
        quant_dtype: torch.dtype,
        group_shape: GroupShape = GroupShape.PER_TOKEN,
        symmetric: bool = True,
    ) -> None:
        # Dynamic scale descriptor; key selects the fused op variant.
        scale = ScaleDesc(torch.float32, False, group_shape)
        key = FusedRMSQuantKey(
            fused_add=False,
            quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
        )
        super().__init__(epsilon, key)

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern/replacement pair with the pattern matcher."""

        def pattern(
            input: torch.Tensor, weight: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor]:
            result_rms = self.rmsnorm_matcher(input, weight)
            # result, scale
            return self.quant_matcher(result_rms)  # type: ignore[no-any-return]

        def replacement(
            input: torch.Tensor, weight: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # In case we're matching native rms-norm, conversions might be
            # optimized out. We convert here just to be safe.
            input = input.to(dtype=self.model_dtype)

            result = torch.empty_like(input, dtype=self.quant_dtype)
            # Dynamic quant: the scale buffer is an output of the fused op.
            scale = self.quant_matcher.make_scale(input)
            at = auto_functionalized(
                self.FUSED_OP,
                result=result,
                input=input,
                weight=weight,
                scale=scale,
                epsilon=self.epsilon,
                scale_ub=None,
                residual=None,
            )

            # result, scale
            return at[1], at[2]

        pm.register_replacement(
            pattern,
            replacement,
            self.rmsnorm_matcher.inputs(),
            pm.fwd_only,
            pm_pass,
        )
|
||||
|
||||
|
||||
class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
    """Fuses fused_add_rms_norm + dynamic (per-token by default) FP8 quant
    into ``self.FUSED_OP``."""

    def __init__(
        self,
        epsilon: float,
        quant_dtype: torch.dtype,
        group_shape: GroupShape = GroupShape.PER_TOKEN,
        symmetric: bool = True,
    ) -> None:
        # Dynamic scale descriptor; key selects the fused op variant.
        scale = ScaleDesc(torch.float32, False, group_shape)
        key = FusedRMSQuantKey(
            fused_add=True,
            quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
        )
        super().__init__(epsilon, key)

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern/replacement pair with the pattern matcher."""

        def pattern(
            input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
            result, scale = self.quant_matcher(result_rms)

            return result, residual, scale

        def replacement(
            input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            # In case we're matching native rms-norm, conversions might be
            # optimized out. We convert here just to be safe.
            input = input.to(dtype=self.model_dtype)

            result = torch.empty_like(input, dtype=self.quant_dtype)
            # Dynamic quant: the scale buffer is an output of the fused op.
            scale = self.quant_matcher.make_scale(input)
            at = auto_functionalized(
                self.FUSED_OP,
                result=result,
                input=input,
                weight=weight,
                scale=scale,
                epsilon=self.epsilon,
                scale_ub=None,
                residual=residual,
            )

            # result, residual, scale
            return at[1], at[3], at[2]

        pm.register_replacement(
            pattern,
            replacement,
            self.rmsnorm_matcher.inputs(),
            pm.fwd_only,
            pm_pass,
        )
|
||||
|
||||
|
||||
class RMSNormQuantFusionPass(VllmPatternMatcherPass):
    """
    This pass fuses rms_norm & quant custom ops into a fused rms_norm_quant op.
    It also supports fused_add_rms_norm.
    """

    @enable_fake_mode
    def __init__(self, config: VllmConfig) -> None:
        super().__init__(config)

        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="rmsnorm_quant_fusion_pass"
        )

        # Make sure fused add patterns are before simple rms norm,
        # as the latter is a subset of the former in torch ops
        for epsilon in [1e-5, 1e-6]:
            # Fuse fused_add_rms_norm + static fp8 quant
            FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(
                self.patterns
            )

            # Fuse rms_norm + static fp8 quant
            RMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)

            # Fuse fused_add_rms_norm + dynamic per-token fp8 quant
            FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
                self.patterns
            )

            # Fuse rms_norm + dynamic per-token fp8 quant
            RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)

            # Only register group quant patterns on CUDA where the C++ op exists
            if current_platform.is_cuda():
                # One pattern per combination of group shape / scale layout /
                # scale encoding / TMA alignment, since each traces to a
                # distinct op-call form.
                for group_shape in [GroupShape(1, 128), GroupShape(1, 64)]:
                    for has_col_major_scales in [True, False]:
                        for is_e8m0 in [True, False]:
                            for is_tma_aligned in [False, True]:
                                # Fuse fused_add_rms_norm + fp8 group quant
                                FusedAddRMSNormGroupQuantPattern(
                                    epsilon,
                                    FP8_DTYPE,
                                    group_shape=group_shape,
                                    is_e8m0=is_e8m0,
                                    has_col_major_scales=has_col_major_scales,
                                    is_tma_aligned=is_tma_aligned,
                                ).register(self.patterns)

                                # Fuse rms_norm + fp8 group quant
                                RMSNormGroupQuantPattern(
                                    epsilon,
                                    FP8_DTYPE,
                                    group_shape=group_shape,
                                    is_e8m0=is_e8m0,
                                    has_col_major_scales=has_col_major_scales,
                                    is_tma_aligned=is_tma_aligned,
                                ).register(self.patterns)

        self.dump_patterns(config, self.patterns)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph) -> None:
        """Apply all registered patterns to the graph in place."""
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)

    def uuid(self) -> str:
        """Hash the pass and every pattern class it uses so cached compiled
        artifacts are invalidated when any of them changes."""
        return self.hash_source(
            self,
            RMSNormGroupQuantPattern,
            RMSNormQuantPattern,
            RMSNormStaticQuantPattern,
            RMSNormDynamicQuantPattern,
            FusedAddRMSNormStaticQuantPattern,
            FusedAddRMSNormDynamicQuantPattern,
            FusedAddRMSNormGroupQuantPattern,
        )
|
||||
504
vllm/compilation/passes/fusion/rocm_aiter_fusion.py
Normal file
504
vllm/compilation/passes/fusion/rocm_aiter_fusion.py
Normal file
@@ -0,0 +1,504 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
import torch._inductor.pattern_matcher as pm
|
||||
from torch import fx
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||
from torch._ops import OpOverload
|
||||
|
||||
import vllm.model_executor.layers.quantization.utils.fp8_utils # noqa: F401
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
QuantKey,
|
||||
ScaleDesc,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ..inductor_pass import enable_fake_mode
|
||||
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
from .act_quant_fusion import ActivationQuantPattern
|
||||
from .matcher_utils import (
|
||||
MatcherFusedAddRMSNorm,
|
||||
MatcherQuantFP8,
|
||||
MatcherRMSNorm,
|
||||
MatcherSiluAndMul,
|
||||
)
|
||||
from .rms_quant_fusion import (
|
||||
FusedRMSQuantKey,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
|
||||
|
||||
class AiterRMSNormQuantPattern:
    """Shared state for the AITER RMSNorm + quant fusion patterns.

    Holds the epsilon, target quant dtype, and the two sub-pattern matchers
    (rms-norm variant chosen by ``key.fused_add``; quant matcher chosen by
    ``match_aiter_quant``).
    """

    def __init__(
        self, epsilon: float, key: FusedRMSQuantKey, match_aiter_quant: bool = True
    ):
        self.epsilon = epsilon
        self.quant_dtype = key.quant.dtype

        # Pick the rms-norm matcher variant matching whether the pattern
        # includes the fused residual add.
        if key.fused_add:
            norm_matcher = MatcherFusedAddRMSNorm(epsilon, match_rocm_aiter=True)
        else:
            norm_matcher = MatcherRMSNorm(epsilon, match_rocm_aiter=True)
        self.rmsnorm_matcher = norm_matcher

        self.quant_matcher = MatcherQuantFP8(
            key.quant,
            match_rocm_aiter=match_aiter_quant,
        )
|
||||
|
||||
|
||||
class AiterRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern):
    """AITER RMSNorm + Dynamic Quantization pattern."""

    # Fused AITER op that performs rms-norm and dynamic quant in one kernel.
    FUSED_OP = rocm_aiter_ops.get_rmsnorm_fused_dynamic_quant_op()

    def __init__(
        self,
        epsilon: float,
        quant_dtype: torch.dtype,
        match_aiter_quant: bool = True,
        group_shape: GroupShape = GroupShape.PER_TOKEN,
        symmetric: bool = True,
    ) -> None:
        scale = ScaleDesc(torch.float32, False, group_shape)
        key = FusedRMSQuantKey(
            fused_add=False,
            quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
        )

        super().__init__(epsilon, key, match_aiter_quant)

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern/replacement pair with the pattern matcher."""

        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            result_rms = self.rmsnorm_matcher(input, weight)
            result, scale = self.quant_matcher(result_rms)
            return result, scale

        def replacement(
            input: torch.Tensor,
            weight: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            result = self.FUSED_OP(
                x=input,
                weight=weight,
                epsilon=self.epsilon,
                quant_dtype=self.quant_dtype,
            )

            # (quantized result, scale)
            return result[0], result[1]

        pm.register_replacement(
            pattern,
            replacement,
            self.rmsnorm_matcher.inputs(),
            pm.fwd_only,
            pm_pass,
        )
|
||||
|
||||
|
||||
class AiterFusedAddRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern):
    """AITER RMSNorm Fused Add + Dynamic Quantization pattern."""

    # Fused AITER op: residual add + rms-norm + dynamic quant in one kernel.
    FUSED_OP = rocm_aiter_ops.get_rmsnorm_fused_add_dynamic_quant_op()

    def __init__(
        self,
        epsilon: float,
        quant_dtype: torch.dtype,
        match_aiter_quant: bool = True,
        group_shape: GroupShape = GroupShape.PER_TOKEN,
        symmetric: bool = True,
    ) -> None:
        scale = ScaleDesc(torch.float32, False, group_shape)
        key = FusedRMSQuantKey(
            fused_add=True,
            quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
        )

        super().__init__(epsilon, key, match_aiter_quant)

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern/replacement pair with the pattern matcher."""

        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            result_rms, residual_out = self.rmsnorm_matcher(input, weight, residual)
            result, scale = self.quant_matcher(result_rms)

            return result, residual_out, scale

        def replacement(
            input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            result = self.FUSED_OP(
                x=input,
                residual=residual,
                weight=weight,
                epsilon=self.epsilon,
                quant_dtype=self.quant_dtype,
            )

            # (quantized result, residual out, scale)
            return result[0], result[1], result[2]

        pm.register_replacement(
            pattern,
            replacement,
            self.rmsnorm_matcher.inputs(),
            pm.fwd_only,
            pm_pass,
        )
|
||||
|
||||
|
||||
class AiterRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern):
    """
    This pattern fuses aiter rms_norm & group fp8 quant custom
    ops into an aiter rms_norm_group_fp8_quant op.
    """

    FUSED_OP = rocm_aiter_ops.get_rmsnorm_group_fused_quant_op()

    def __init__(
        self,
        epsilon: float,
        quant_dtype: torch.dtype,
        group_shape: GroupShape,
        match_aiter_quant: bool = True,
        symmetric: bool = True,
    ) -> None:
        scale = ScaleDesc(torch.float32, False, group_shape)
        key = FusedRMSQuantKey(
            fused_add=False,
            quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
        )
        # Keep the group shape so the replacement uses the shape this pattern
        # was constructed with instead of a hard-coded group size.
        self.group_shape = group_shape

        super().__init__(epsilon, key, match_aiter_quant)

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern/replacement pair with the pattern matcher."""

        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            result_rms = self.rmsnorm_matcher(input, weight)
            result, scale = self.quant_matcher(result_rms)
            return result, scale

        def replacement(
            input: torch.Tensor,
            weight: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            at = self.FUSED_OP(
                x=input,
                weight=weight,
                variance_epsilon=self.epsilon,
                # Previously hard-coded to 128; derive from the configured
                # group shape so non-128 group sizes fuse correctly too.
                group_size=self.group_shape[1],
            )

            # (quantized result, scale)
            return at[0], at[1]

        pm.register_replacement(
            pattern, replacement, self.rmsnorm_matcher.inputs(), pm.fwd_only, pm_pass
        )
|
||||
|
||||
|
||||
class AiterFusedAddRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern):
    """
    This pattern fuses aiter rms_norm_with_add & group fp8 quant custom ops
    into a aiter rms_norm_with_add_group_fp8_quant op.
    """

    FUSED_OP = rocm_aiter_ops.get_rmsnorm_group_add_fused_quant_op()

    def __init__(
        self,
        epsilon: float,
        quant_dtype: torch.dtype,
        group_shape: GroupShape,
        match_aiter_quant: bool = True,
        symmetric: bool = True,
    ) -> None:
        scale = ScaleDesc(torch.float32, False, group_shape)
        key = FusedRMSQuantKey(
            fused_add=True,
            quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
        )
        # Keep the group shape so the replacement uses the shape this pattern
        # was constructed with instead of a hard-coded group size.
        self.group_shape = group_shape

        super().__init__(epsilon, key, match_aiter_quant)

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern/replacement pair with the pattern matcher."""

        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            result_rms, residual_out = self.rmsnorm_matcher(input, weight, residual)
            result, scale = self.quant_matcher(result_rms)

            return result, residual_out, scale

        def replacement(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            at = self.FUSED_OP(
                x=input,
                residual=residual,
                weight=weight,
                variance_epsilon=self.epsilon,
                # Previously hard-coded to 128; derive from the configured
                # group shape so non-128 group sizes fuse correctly too.
                group_size=self.group_shape[1],
            )

            # result, scale, residual
            return at[0], at[1], at[2]

        pm.register_replacement(
            pattern, replacement, self.rmsnorm_matcher.inputs(), pm.fwd_only, pm_pass
        )
|
||||
|
||||
|
||||
class RocmAiterRMSNormQuantFusionPass(VllmPatternMatcherPass):
    """
    This pass fuses aiter rms_norm & vllm/aiter quant custom ops
    into a fused rms_norm_quant op.
    It also supports fused_add_rms_norm.
    """

    @enable_fake_mode
    def __init__(self, config: VllmConfig) -> None:
        super().__init__(config)

        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="rocm_aiter_rms_norm_quant_fusion_pass"
        )

        # Make sure fused add patterns are before simple rms norm,
        # as the latter is a subset of the former in torch ops
        for epsilon in [1e-5, 1e-6]:
            # Fuse aiter rms_norm + aiter dynamic group fp8 quant
            AiterRMSFp8GroupQuantPattern(
                epsilon, FP8_DTYPE, GroupShape(1, 128)
            ).register(self.patterns)

            # Fuse aiter fused_add_rms_norm + aiter dynamic group fp8 quant
            AiterFusedAddRMSFp8GroupQuantPattern(
                epsilon, FP8_DTYPE, GroupShape(1, 128)
            ).register(self.patterns)

            # Register both quant-matcher flavors so either the AITER or the
            # vLLM built-in quant op is recognized in the graph.
            for match_aiter_quant in [True, False]:
                # Fuse aiter rms_norm + (aiter / vllm built-in)
                # dynamic per-token fp8 quant
                AiterRMSNormDynamicQuantPattern(
                    epsilon, FP8_DTYPE, match_aiter_quant=match_aiter_quant
                ).register(self.patterns)

                # Fuse aiter fused_add_rms_norm + (aiter / vllm built-in)
                # dynamic per-token fp8 quant
                AiterFusedAddRMSNormDynamicQuantPattern(
                    epsilon, FP8_DTYPE, match_aiter_quant=match_aiter_quant
                ).register(self.patterns)

        self.dump_patterns(config, self.patterns)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph) -> None:
        """Apply all registered patterns to the graph in place."""
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)

    def uuid(self) -> str:
        """Hash the pass and its pattern classes for cache invalidation."""
        fusion_patterns = [
            AiterRMSNormDynamicQuantPattern,
            AiterFusedAddRMSNormDynamicQuantPattern,
            AiterRMSFp8GroupQuantPattern,
            AiterFusedAddRMSFp8GroupQuantPattern,
        ]
        return self.hash_source(self, *fusion_patterns)
|
||||
|
||||
|
||||
class AiterSiluMulFp8GroupQuantPattern(ActivationQuantPattern):
    """
    This pattern fuses aiter silu_and_mul & group fp8 quant custom
    ops into an aiter silu_and_mul_group_fp8_quant op.
    """

    # Fused AITER op: SiLU-and-mul activation + group FP8 quant in one kernel.
    FUSED_SILU_MUL_QUANT_OP = rocm_aiter_ops.get_act_mul_fused_fp8_group_quant_op()

    def __init__(self, quant_op: OpOverload) -> None:
        # quant_op is the unfused quant op to match (AITER or Triton variant).
        self.silu_and_mul_matcher = MatcherSiluAndMul()
        self.quant_op = quant_op

    def get_inputs(self) -> list[torch.Tensor]:
        """Sample inputs used to trace the pattern (activation input only)."""
        return [
            self.silu_and_mul_matcher.inputs()[0],
        ]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern/replacement pair with the pattern matcher."""

        def pattern(
            input: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            at1 = self.silu_and_mul_matcher(input)
            # NOTE(review): group size 128 is fixed here and in the
            # replacement; both sides must stay in sync.
            at2 = self.quant_op(at1, 128)
            return at2[0], at2[1]

        def replacement(
            input: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            at = self.FUSED_SILU_MUL_QUANT_OP(x=input, group_size=128)
            return at[0], at[1]

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )
|
||||
|
||||
|
||||
class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass):
    """
    This pass fuses a pre-defined set of custom ops into fused ops.
    It uses the torch pattern matcher to find the patterns and replace them.

    Because patterns can only be registered once, the pass is a singleton.
    This will be addressed in a future version of PyTorch:
    https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
    """

    # The two unfused group-quant ops the pattern may match against.
    AITER_GROUP_FP8_QUANT_OP = rocm_aiter_ops.get_group_quant_op()
    TRITON_GROUP_FP8_QUANT_OP = torch.ops.vllm.triton_per_token_group_quant_fp8.default

    QUANT_OPS = [AITER_GROUP_FP8_QUANT_OP, TRITON_GROUP_FP8_QUANT_OP]

    @enable_fake_mode
    def __init__(self, config: VllmConfig) -> None:
        super().__init__(config)

        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="rocm_aiter_silu_mul_fp8_group_quant_fusion_pass"
        )

        # One pattern per quant-op flavor so either is recognized.
        for quant_op in self.QUANT_OPS:
            AiterSiluMulFp8GroupQuantPattern(quant_op).register(self.patterns)

        self.dump_patterns(config, self.patterns)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: torch.fx.Graph) -> None:
        """Apply all registered patterns to the graph in place."""
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)

    def uuid(self) -> str:
        """Hash the pass and its pattern classes for cache invalidation."""
        fusion_patterns = [
            ActivationQuantPattern,
            AiterSiluMulFp8GroupQuantPattern,
        ]
        return VllmInductorPass.hash_source(self, *fusion_patterns)
|
||||
|
||||
|
||||
class AddAiterRMSNormPadPattern:
    """
    This pattern replaces an aiter_rmsnorm_with_add & a pad op
    with a custom triton_add_rmsnorm_pad op from AITER.

    The router GEMM in between is kept unfused; in the replacement it reads
    the unpadded slice of the fused op's padded output.
    """

    AITER_TRITON_ADD_RMSNORM_PAD_OP = rocm_aiter_ops.get_triton_add_rmsnorm_pad_op()

    def __init__(
        self,
        epsilon: float,
        hidden_size: int,
        x_pad_to_multiple: int,
    ):
        self.epsilon = epsilon
        self.hidden_size = hidden_size
        # Output hidden dim is padded up to a multiple of this value.
        self.x_pad_to_multiple = x_pad_to_multiple
        self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon, match_rocm_aiter=True)

    def get_inputs(self) -> list[torch.Tensor]:
        """Sample inputs for tracing; router weight/bias shapes are arbitrary
        placeholders (8 experts x 16 features) used only for tracing."""
        input, weight, residual = self.rmsnorm_matcher.inputs()
        router_weight = torch.empty([8, 16], dtype=weight.dtype, device=weight.device)
        router_bias = torch.empty([8], dtype=weight.dtype, device=weight.device)
        return [input, weight, residual, router_weight, router_bias]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern/replacement pair with the pattern matcher."""

        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
            router_weight: torch.Tensor,
            router_bias: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            # Amount of right-padding needed to reach the next multiple.
            # NOTE(review): assumes hidden_size is not already a multiple of
            # x_pad_to_multiple (otherwise pad_size == x_pad_to_multiple, not 0).
            pad_size = self.x_pad_to_multiple - (
                self.hidden_size % self.x_pad_to_multiple
            )
            result_rms, residual_out = self.rmsnorm_matcher(input, weight, residual)
            router_logits = torch.ops.vllm.rocm_unquantized_gemm(
                result_rms, router_weight, router_bias
            )
            result = torch.nn.functional.pad(
                result_rms, (0, pad_size), mode="constant", value=0.0
            )
            return result, residual_out, router_logits

        def replacement(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
            router_weight: torch.Tensor,
            router_bias: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            at = self.AITER_TRITON_ADD_RMSNORM_PAD_OP(
                x=input,
                weight=weight,
                variance_epsilon=self.epsilon,
                residual=residual,
                x_pad_to_multiple=self.x_pad_to_multiple,
            )
            result_padded = at[0]
            # The router GEMM consumes the unpadded prefix of the padded output.
            router_logits = torch.ops.vllm.rocm_unquantized_gemm(
                result_padded[:, : self.hidden_size], router_weight, router_bias
            )
            residual_out = at[1]
            return result_padded, residual_out, router_logits

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )
|
||||
|
||||
|
||||
class RocmAiterTritonAddRMSNormPadFusionPass(VllmPatternMatcherPass):
    """
    This pass replaces an AITER CK RMSNorm + residual add and a pad op
    with an triton_add_rmsnorm_pad op from AITER.
    """

    def __init__(self, config: VllmConfig):
        super().__init__(config)
        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="rocm_aiter_triton_add_rmsnorm_pad_fusion_pass"
        )

        # gpt-oss has hidden size 2880
        # padded to a multiple of 128 on gfx942 and 256 on gfx950 respectively
        # NOTE(review): hidden_size is hard-coded for gpt-oss; other models
        # with a pad-after-norm pattern would need their own entry.
        hidden_size = 2880
        for epsilon in [1e-5, 1e-6]:
            for x_pad_to_multiple in [128, 256]:
                AddAiterRMSNormPadPattern(
                    epsilon, hidden_size, x_pad_to_multiple
                ).register(self.patterns)

        self.dump_patterns(config, self.patterns)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: torch.fx.Graph) -> None:
        """Apply all registered patterns to the graph in place."""
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)

    def uuid(self) -> str:
        """Hash the pass and its pattern class for cache invalidation."""
        return VllmInductorPass.hash_source(self, AddAiterRMSNormPadPattern)
|
||||
230
vllm/compilation/passes/fusion/rope_kvcache_fusion.py
Normal file
230
vllm/compilation/passes/fusion/rope_kvcache_fusion.py
Normal file
@@ -0,0 +1,230 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
import torch._inductor.pattern_matcher as pm
|
||||
from torch import fx
|
||||
from torch._higher_order_ops import auto_functionalized
|
||||
from torch._inductor.fx_passes.post_grad import view_to_reshape
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||
|
||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||
from vllm.config.utils import Range
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention.attention import (
|
||||
Attention,
|
||||
get_attention_context,
|
||||
)
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
from ..inductor_pass import enable_fake_mode
|
||||
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
from .matcher_utils import (
|
||||
MatcherRotaryEmbedding,
|
||||
)
|
||||
from .rms_quant_fusion import (
|
||||
empty_bf16,
|
||||
empty_i64,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def fused_rope_and_unified_kv_cache_update_impl(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    positions: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
    layer_name: str = "",
) -> torch.Tensor:
    """
    This impl fetches the KV cache and slot mapping from the forward context,
    then calls the layer impl's `AttentionImpl.do_rope_and_kv_cache_update` method.
    It also returns a dummy tensor, similar to `Attention.unified_kv_cache_update`,
    that is passed to unified_attention to signal a side effect and
    the data dependency between them to ensure torch.compile preserves ordering.
    """
    _, attn_layer, kv_cache, layer_slot_mapping = get_attention_context(layer_name)
    # No slot mapping means there is nothing to write for this layer
    # (e.g. profiling run); skip the RoPE + cache update entirely.
    if layer_slot_mapping is not None:
        attn_layer.impl.do_rope_and_kv_cache_update(
            attn_layer,
            query,
            key,
            value,
            positions,
            cos_sin_cache,
            is_neox,
            kv_cache,
            layer_slot_mapping,
        )

    # Zero-element dummy output carrying the data dependency; matches the
    # KV cache's device/dtype.
    return torch.empty(0, device=kv_cache.device, dtype=kv_cache.dtype)
|
||||
|
||||
|
||||
def fused_rope_and_unified_kv_cache_update_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    positions: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
    layer_name: str = "",
) -> torch.Tensor:
    """Fake (meta) implementation for tracing: produce the zero-element dummy
    tensor without touching any forward context or KV cache.

    Matches the query's device and dtype (the real impl uses the KV cache's,
    which is unavailable at trace time).
    """
    return query.new_empty(0)
|
||||
|
||||
|
||||
# Register the fused RoPE + KV-cache-update op with PyTorch so it can appear
# in torch.compile'd graphs, with the fake impl used during tracing.
# NOTE(review): the impl also writes the KV cache obtained from the forward
# context, but only query/key are declared in mutates_args — presumably the
# cache write is outside the op's visible arguments and is ordered via the
# returned dummy tensor instead; confirm against direct_register_custom_op's
# contract.
direct_register_custom_op(
    op_name="fused_rope_and_unified_kv_cache_update",
    op_func=fused_rope_and_unified_kv_cache_update_impl,
    mutates_args=["query", "key"],
    fake_impl=fused_rope_and_unified_kv_cache_update_fake,
)
|
||||
|
||||
|
||||
class RopeReshapeKVCachePattern:
    """
    This pattern matches the following unfused inplace ops:
    q, k = rotary_embedding(positions, q, k, head_size, cos_sin_cache, is_neox)
    kv_cache_dummy = unified_kv_cache_update(k, v, layer_name)

    and replaces it with the fused inplace op:
    kv_cache_dummy = fused_rope_and_unified_kv_cache_update(
        q, k, v, positions, cos_sin_cache, is_neox, layer_name
    )
    """

    # Handle to the fused custom op registered above in this module.
    FUSED_OP = torch.ops.vllm.fused_rope_and_unified_kv_cache_update.default

    def __init__(
        self,
        layer: Attention,
        is_neox: bool,
    ) -> None:
        # Per-layer geometry, used both to build the tracing inputs and to
        # split/reshape q, k, v inside the traced pattern.
        self.layer_name = layer.layer_name
        self.num_heads = layer.num_heads
        self.num_kv_heads = layer.num_kv_heads
        self.head_size = layer.head_size
        self.head_size_v = layer.head_size_v
        self.is_neox = is_neox

        # Flattened widths of the q/k/v partitions of the packed qkv tensor.
        self.q_size = self.num_heads * self.head_size
        self.k_size = self.num_kv_heads * self.head_size
        self.v_size = self.num_kv_heads * self.head_size_v

        self.rope_matcher = MatcherRotaryEmbedding(
            is_neox=self.is_neox,
            head_size=self.head_size,
            num_heads=self.num_heads,
            num_kv_heads=self.num_kv_heads,
        )

    def get_inputs(self) -> list[torch.Tensor]:
        # Sample inputs to help pattern tracing
        T = 5  # arbitrary token count for tracing
        L = 4096  # arbitrary cos/sin cache length for tracing
        qkv = empty_bf16(T, self.q_size + self.k_size + self.v_size)
        positions = empty_i64(T)
        cos_sin_cache = empty_bf16(L, self.head_size)
        return [
            qkv,
            positions,
            cos_sin_cache,
        ]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Trace the pattern/replacement pair and register it on pm_pass."""

        def pattern(
            qkv: torch.Tensor,
            positions: torch.Tensor,
            cos_sin_cache: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
            # Unfused form: split packed qkv, apply RoPE, reshape to
            # (tokens, heads, head_size), then update the KV cache.
            q, k, v = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1)
            q, k = self.rope_matcher(positions, q, k, cos_sin_cache)
            q = q.view(-1, self.num_heads, self.head_size)
            k = k.view(-1, self.num_kv_heads, self.head_size)
            v = v.view(-1, self.num_kv_heads, self.head_size_v)
            dummy = torch.ops.vllm.unified_kv_cache_update(k, v, self.layer_name)
            return dummy, q, k, v

        def replacement(
            qkv: torch.Tensor,
            positions: torch.Tensor,
            cos_sin_cache: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
            q, k, v = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1)
            q = q.view(-1, self.num_heads, self.head_size)
            k = k.view(-1, self.num_kv_heads, self.head_size)
            v = v.view(-1, self.num_kv_heads, self.head_size_v)
            # auto_functionalized wraps the in-place fused op; per the return
            # below, results[0] is the dummy token and results[1]/results[2]
            # are the (mutated) query/key.
            results = auto_functionalized(
                self.FUSED_OP,
                query=q,
                key=k,
                value=v,
                positions=positions,
                cos_sin_cache=cos_sin_cache,
                is_neox=self.is_neox,
                layer_name=self.layer_name,
            )
            return results[0], results[1], results[2], v

        # NOTE: use view_to_reshape to unify view/reshape to simplify
        # pattern and increase matching opportunities
        def fwd_and_view_to_reshape(*args, **kwargs) -> fx.GraphModule:
            gm = pm.fwd_only(*args, **kwargs)
            view_to_reshape(gm)
            return gm

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), fwd_and_view_to_reshape, pm_pass
        )
|
||||
|
||||
|
||||
class RopeKVCacheFusionPass(VllmPatternMatcherPass):
    """
    This pass fuses the rotary embedding and KV cache update operations
    into a single fused kernel if available.

    It uses the pattern matcher and matches each layer manually, as strings
    cannot be wildcarded. This also lets us check support on attention layers
    upon registration instead of during pattern matching.

    This fusion eliminates the need for separate kernel launches and
    intermediate memory operations between the RoPE and cache update steps.
    """

    @enable_fake_mode
    def __init__(self, config: VllmConfig) -> None:
        super().__init__(config)

        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="rope_kv_cache_fusion_pass"
        )

        cc = config.compilation_config
        # Upper bound on the token count for which this fusion is applied
        # (see is_applicable_for_range below).
        self.max_token_num = cc.pass_config.rope_kvcache_fusion_max_token_num

        attn_layers = get_layers_from_vllm_config(config, Attention)
        for _, layer in attn_layers.items():
            if layer.impl.fused_rope_kvcache_supported():
                # Register patterns for both RoPE layouts (is_neox True/False).
                for is_neox in [True, False]:
                    RopeReshapeKVCachePattern(
                        layer=layer,
                        is_neox=is_neox,
                    ).register(self.patterns)

        self.dump_patterns(config, self.patterns)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph) -> None:
        # Apply all registered per-layer patterns to the graph.
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)

    def is_applicable_for_range(self, compile_range: Range) -> bool:
        # This pass works best for the small-batch decode setting.
        # For large-batch e.g. prefill, it is better to use two separate kernels
        # since they are compute bound and the fused kernels require further tuning.
        return compile_range.end <= self.max_token_num

    def uuid(self) -> str:
        # Include the pattern class source so that changes to the pattern also
        # invalidate the Inductor code cache.
        return VllmInductorPass.hash_source(self, RopeReshapeKVCachePattern)
|
||||
452
vllm/compilation/passes/fusion/sequence_parallelism.py
Normal file
452
vllm/compilation/passes/fusion/sequence_parallelism.py
Normal file
@@ -0,0 +1,452 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import functools
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch._inductor.pattern_matcher as pm
|
||||
import torch.fx as fx
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.utils import Range
|
||||
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
|
||||
from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
kFp8StaticTensorSym,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ..inductor_pass import enable_fake_mode
|
||||
from ..utility.noop_elimination import NoOpEliminationPass
|
||||
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Min hidden size per device capability for sequence parallelism
# Only apply sequence parallelism for models with hidden_size >= threshold
# Keys are device capabilities as ints (e.g. 90, per the H100 entry below).
SP_MIN_HIDDEN_SIZE: dict[int, int] = {
    90: 8192,  # H100: only for models with hidden_size >= 8192
}

# Min size per GPU per device capability for sequence parallelism
# Total min size = min_per_gpu_size * tp_size
# This ensures the threshold scales appropriately with tensor parallelism
SP_MIN_PER_GPU_SIZE_MB: dict[int, float] = {
    90: 8,  # 8MB per GPU for H100
}
|
||||
|
||||
|
||||
def get_sequence_parallelism_threshold(
    hidden_size: int,
    tp_size: int,
    element_size: int,
) -> int | None:
    """
    Calculate the minimum token threshold for applying sequence parallelism.

    Returns None if sequence parallelism should not be applied based on model size.

    Branching logic based on device capability:
    - Check if hidden_size >= SP_MIN_HIDDEN_SIZE[device_capability]
    - If not, returns None (SP disabled for small models on this device)
    - If yes, calculates threshold based on per-GPU size

    Formula: min_token_num = (min_per_gpu_size_mb * tp_size * MiB) //
                             (hidden_size * element_size)

    :param hidden_size: model hidden size in elements.
    :param tp_size: tensor-parallel world size.
    :param element_size: size of one activation element in bytes.
    :return: minimum token count, or None when SP should stay disabled.
    """
    # NOTE: `current_platform` is already imported at module level; the
    # previous redundant function-local re-import was removed.
    if not current_platform.is_cuda():
        return None

    capability = current_platform.get_device_capability()
    if capability is None:
        return None
    device_capability = capability.to_int()

    # Check if device has configured thresholds
    min_hidden_size = SP_MIN_HIDDEN_SIZE.get(device_capability)
    min_per_gpu_size_mb = SP_MIN_PER_GPU_SIZE_MB.get(device_capability)

    if min_hidden_size is None or min_per_gpu_size_mb is None:
        return None

    # Only apply sequence parallelism for models meeting the size threshold
    if hidden_size < min_hidden_size:
        return None

    MiB = 1024 * 1024
    min_size = min_per_gpu_size_mb * MiB * tp_size
    return int(min_size // (hidden_size * element_size))
|
||||
|
||||
|
||||
def get_first_out_wrapper(
    fn: Callable[..., Sequence[torch.Tensor]],
) -> Callable[..., torch.Tensor]:
    """Wrap *fn* so that only the first element of its output sequence is returned.

    Used to register extra pattern/replacement variants that consume only the
    first output of a multi-output pattern.

    :param fn: callable returning a sequence of tensors.
    :return: callable with the same inputs returning only ``fn(...)[0]``.
    """

    @functools.wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> torch.Tensor:
        # Forward keyword arguments too, so the wrapper works with both
        # positional and keyword calling conventions (previously kwargs were
        # silently dropped).
        return fn(*args, **kwargs)[0]

    return wrapper
|
||||
|
||||
|
||||
class _SequenceParallelPatternHelper:
    """Helper for sequence parallelism patterns.

    Holds the epsilon/dtype/device used to trace patterns and provides the
    collective building blocks (all-reduce, reduce-scatter, all-gather) that
    the concrete pattern classes compose.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str | None,
    ) -> None:
        self.epsilon = epsilon  # RMSNorm epsilon the matcher is built with
        self.dtype = dtype  # dtype of the tracing sample tensors
        self.device = device  # device of the tracing sample tensors
        self.tp_group = get_tp_group()
        self.tp_size = get_tensor_model_parallel_world_size()

    def _all_reduce(self, x: torch.Tensor) -> torch.Tensor:
        # The collective being matched/replaced by the patterns.
        return tensor_model_parallel_all_reduce(x)

    def _reduce_scatter(self, x: torch.Tensor) -> torch.Tensor:
        # Scatter along dim 0 across the tensor-parallel group.
        return torch.ops.vllm.reduce_scatter.default(
            x, dim=0, world_size=self.tp_size, group_name=self.tp_group.unique_name
        )

    def _all_gather(self, x: torch.Tensor) -> torch.Tensor:
        # Gather along dim 0 across the tensor-parallel group.
        return torch.ops.vllm.all_gather.default(
            x, dim=0, world_size=self.tp_size, group_name=self.tp_group.unique_name
        )
|
||||
|
||||
|
||||
class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
    """Matches AllReduce -> RMSNorm (no residual) and rewrites it to
    ReduceScatter -> RMSNorm -> AllGather."""

    def __init__(self, epsilon: float, dtype: torch.dtype, device: str | None) -> None:
        super().__init__(epsilon, dtype, device)
        self.rmsnorm_matcher = MatcherRMSNorm(epsilon)

    def get_inputs(self) -> list[torch.Tensor]:
        # Small sample tensors used only for tracing the pattern.
        input = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype)
        arg3_1 = torch.empty([4], device=self.device, dtype=self.dtype)

        return [input, arg3_1]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        def pattern(
            input: torch.Tensor,
            arg3_1: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            all_reduce = self._all_reduce(input)
            rmsnorm = self.rmsnorm_matcher(all_reduce, arg3_1)

            return rmsnorm, all_reduce

        def replacement(
            input: torch.Tensor,
            arg3_1: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # Run the norm on the scattered shard, then gather — this sets up
            # later collective-fusion passes (see SequenceParallelismPass doc).
            reduce_scatter = self._reduce_scatter(input)

            rmsnorm = self.rmsnorm_matcher(reduce_scatter, arg3_1)
            all_gather = self._all_gather(rmsnorm)
            return all_gather, reduce_scatter

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )
|
||||
|
||||
|
||||
class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
    """Matches AllReduce -> FusedAddRMSNorm (with residual) and rewrites it to
    ReduceScatter -> FusedAddRMSNorm -> AllGather."""

    def __init__(self, epsilon: float, dtype: torch.dtype, device: str | None) -> None:
        super().__init__(epsilon, dtype, device)
        self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)

    def get_inputs(self) -> list[torch.Tensor]:
        # Small sample tensors used only for tracing the pattern.
        mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)

        residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype)

        return [
            residual,
            mm_1,
            rms_norm_weights,
        ]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        def pattern(
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            all_reduce = self._all_reduce(mm_1)
            rmsnorm = self.rmsnorm_matcher(all_reduce, rms_norm_weights, residual)
            return rmsnorm[0], rmsnorm[1]

        def replacement(
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # pattern matcher replaces from top-to-bottom,
            # so residual is still the full size here.
            # add a temporary slice which will become a noop
            # once the seqpar pattern with the previous rmsnorm is replaced
            reduce_scatter = self._reduce_scatter(mm_1)
            residual = residual[0 : reduce_scatter.size(0), ...]
            rmsnorm = self.rmsnorm_matcher(reduce_scatter, rms_norm_weights, residual)
            all_gather = self._all_gather(rmsnorm[0])
            # shape of residual changes but that's fine,
            # next node is already slicing it, now becomes a noop
            return all_gather, rmsnorm[1]

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )
        # Also register a variant that consumes only the first output, for
        # graphs where the second output is unused.
        pm.register_replacement(
            get_first_out_wrapper(pattern),
            get_first_out_wrapper(replacement),
            self.get_inputs(),
            pm.fwd_only,
            pm_pass,
        )
|
||||
|
||||
|
||||
# Platform-native FP8 dtype, resolved once at import time.
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
|
||||
|
||||
class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
    """Matches AllReduce -> RMSNorm -> static FP8 quant and rewrites it to
    ReduceScatter -> RMSNorm -> quant -> AllGather."""

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str | None,
    ) -> None:
        super().__init__(epsilon, dtype, device)
        self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
        # Static per-tensor symmetric FP8 quantization scheme.
        self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)

    def get_inputs(self) -> list[torch.Tensor]:
        # Small sample tensors used only for tracing the pattern.
        input = torch.zeros([1, 8, 4], device=self.device, dtype=self.dtype)
        weight = torch.empty([4], device=self.device, dtype=self.dtype)
        scale = torch.tensor(1.0, device=self.device, dtype=torch.float32)
        return [input, weight, scale]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            all_reduce = self._all_reduce(input)
            rms = self.rmsnorm_matcher(all_reduce, weight)
            quant, _ = self.quant_matcher(rms, scale)
            return quant, all_reduce

        def replacement(
            input: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # Norm + quantize locally on the scattered shard, then gather
            # the already-quantized activations.
            reduce_scatter = self._reduce_scatter(input)
            rms = self.rmsnorm_matcher(reduce_scatter, weight)
            quant, _ = self.quant_matcher(rms, scale)
            all_gather = self._all_gather(quant)

            return all_gather, reduce_scatter

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )
|
||||
|
||||
|
||||
class MiddleAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
    """Matches AllReduce -> FusedAddRMSNorm -> static FP8 quant and rewrites
    it to ReduceScatter -> FusedAddRMSNorm -> quant -> AllGather."""

    def __init__(self, epsilon: float, dtype: torch.dtype, device: str | None) -> None:
        super().__init__(epsilon, dtype, device)
        self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
        # Static per-tensor symmetric FP8 quantization scheme.
        self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)

    def get_inputs(self) -> list[torch.Tensor]:
        # Small sample tensors used only for tracing the pattern.
        mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        scale = torch.empty([1, 1], device=self.device, dtype=torch.float32)

        return [residual, mm_1, rms_norm_weights, scale]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        def pattern(
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            all_reduce = self._all_reduce(mm_1)
            rms, residual_out = self.rmsnorm_matcher(
                all_reduce, rms_norm_weights, residual
            )
            quant, _ = self.quant_matcher(rms, scale)
            return quant, residual_out

        def replacement(
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # pattern matcher replaces from top-to-bottom,
            # so residual is still the full size here.
            # add a temporary slice which will become a noop
            # once the seqpar pattern with the previous rmsnorm is replaced
            reduce_scatter = self._reduce_scatter(mm_1)
            residual = residual[0 : reduce_scatter.size(0), ...]
            rms, residual_out = self.rmsnorm_matcher(
                reduce_scatter, rms_norm_weights, residual
            )
            quant, _ = self.quant_matcher(rms, scale)
            all_gather = self._all_gather(quant)
            # shape of residual changes but that's fine,
            # next node is already slicing it, now becomes a noop
            return all_gather, residual_out

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )

        # Also register a variant that consumes only the first output, for
        # graphs where the second output is unused.
        pm.register_replacement(
            get_first_out_wrapper(pattern),
            get_first_out_wrapper(replacement),
            self.get_inputs(),
            pm.fwd_only,
            pm_pass,
        )
|
||||
|
||||
|
||||
class SequenceParallelismPass(VllmPatternMatcherPass):
    """
    This pass enables sequence parallelism for models.
    It identifies patterns where an AllReduce operation is followed by
    an RMSNorm (or RMSNorm and then Quantization) operation.
    These patterns are replaced with a ReduceScatter operation, followed by
    a local RMSNorm/Quantization, and then an AllGather operation.

    The general transformation is:
    Input -> AllReduce -> RMSNorm -> Output
    becomes
    Input -> ReduceScatter -> RMSNorm -> AllGather -> Output

    While this pass itself does not directly yield performance improvements,
    it lays the groundwork for subsequent fusion passes, such as
    GEMM + ReduceScatter and AllGather + GEMM fusions. These fusions can
    significantly reduce communication overhead and improve overall model
    performance.


    This pass splits up the residual tensor across TP ranks and hence divides its size.
    Because the pattern matcher starts at the end of the graph, the replacement
    contains a slice that temporarily conforms the input residual to the correct size.
    After all patterns have been matched, we use a NoOpEliminationPass to clean up
    what have now become no-op slices.

    Note that an older version of the pass did not need this as it operated only on
    custom rms_norm and fused_rms_norm_add custom ops which did not complain about
    mismatched shapes during replacement. So this approach has the same assumption that
    correctness is only maintained if all rms_norm operations are split across ranks.

    Correctness-wise, this approach is strictly better than before - before,
    the graph was incorrect semantically and shape-wise during the pass.
    With this approach there's only semantic incorrectness during the pass.
    Both approaches restore a correct graph once all patterns are matched.
    """

    @enable_fake_mode
    def __init__(self, config: VllmConfig) -> None:
        super().__init__(config)

        # Get min_token_num threshold
        # Read min_token_num from config (calculated during config init)
        self.min_token_num = None
        if config.model_config is not None:
            pass_config = config.compilation_config.pass_config
            self.min_token_num = pass_config.sp_min_token_num

        if self.min_token_num is not None:
            # Take the min to avoid exceeding max_num_batched_tokens
            max_batched = config.scheduler_config.max_num_batched_tokens
            if max_batched is not None:
                self.min_token_num = min(self.min_token_num, max_batched)
            logger.debug_once(
                f"Sequence parallelism min token threshold: {self.min_token_num}",
                scope="global",
            )

        # Used to clean up redundant views created temporarily
        # to circumvent residual shape change issues
        self.noop_cleanup = NoOpEliminationPass(config)
        self.noop_cleanup.pass_name = f"{self.pass_name}.{self.noop_cleanup.pass_name}"

        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="sequence_parallelism_pass"
        )

        # Register every pattern variant for both epsilon values in use.
        for epsilon in [1e-5, 1e-6]:
            # RMSNorm + Static FP8 quantization patterns
            FirstAllReduceRMSNormStaticFP8Pattern(
                epsilon, self.model_dtype, self.device
            ).register(self.patterns)
            MiddleAllReduceRMSNormStaticFP8Pattern(
                epsilon, self.model_dtype, self.device
            ).register(self.patterns)

            # Normal RMSNorm patterns
            FirstAllReduceRMSNormPattern(
                epsilon, self.model_dtype, self.device
            ).register(self.patterns)

            MiddleAllReduceRMSNormPattern(
                epsilon, self.model_dtype, self.device
            ).register(self.patterns)

        self.dump_patterns(config, self.patterns)

    def is_applicable_for_range(self, compile_range: Range) -> bool:
        """
        Determines if sequence parallelism should be applied for the given
        compile range.

        SP is only beneficial for larger batch sizes where the communication
        overhead is amortized. For small batches, the overhead of splitting
        and gathering tensors across TP ranks outweighs the benefits.

        Returns False (SP disabled) when:
        - Using piecewise compilation with non-concrete or TP-indivisible sizes
        - min_token_num is None (SP disabled for this device/config)
        - The compile range starts below the minimum token threshold
        """
        # For piecewise compilation (not using inductor graph partition),
        # we need concrete sizes that are divisible by TP for correct splitting
        if (
            not self.compilation_config.use_inductor_graph_partition
            and self.compilation_config.splitting_ops
        ):
            tp_size = get_tensor_model_parallel_world_size()
            if not compile_range.is_single_size() or compile_range.end % tp_size != 0:
                return False

        # min_token_num is None when SP is disabled for this device/config
        # (e.g., non-CUDA platform, unsupported GPU, or small hidden_size)
        if self.min_token_num is None:
            return False

        # Only apply SP when batch size meets the minimum threshold
        return compile_range.start >= self.min_token_num

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph) -> None:
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)
        # Clean up reshape nodes
        self.noop_cleanup(graph)
|
||||
77
vllm/compilation/passes/fx_utils.py
Normal file
77
vllm/compilation/passes/fx_utils.py
Normal file
@@ -0,0 +1,77 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import operator
|
||||
from collections.abc import Iterable, Iterator
|
||||
|
||||
from torch import fx
|
||||
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||
from torch._ops import OpOverload, OpOverloadPacket
|
||||
from torch.fx.node import Target
|
||||
|
||||
|
||||
def is_func(node: fx.Node, target: Target) -> bool:
    """Return True iff *node* is a call_function node invoking *target*."""
    if node.op != "call_function":
        return False
    return bool(node.target == target)
|
||||
|
||||
|
||||
def is_auto_func(node: fx.Node, op: OpOverload) -> bool:
|
||||
return is_func(node, auto_functionalized) and node.args[0] == op
|
||||
|
||||
|
||||
def find_auto_fn_maybe(nodes: Iterable[fx.Node], op: OpOverload) -> fx.Node | None:
    """Return the first auto_functionalized node wrapping *op*, or None.

    :param nodes: nodes to scan, in order.
    :param op: the op expected as args[0] of the auto_functionalized call.
    """
    # Delegate to is_auto_func instead of duplicating its condition, so the
    # two checks cannot drift apart.
    return next((node for node in nodes if is_auto_func(node, op)), None)
|
||||
|
||||
|
||||
def find_auto_fn(nodes: Iterable[fx.Node], op: OpOverload) -> fx.Node:
    """Return the first auto_functionalized node wrapping *op*; must exist."""
    found = find_auto_fn_maybe(nodes, op)
    assert found is not None, f"Could not find {op} in nodes {nodes}"
    return found
|
||||
|
||||
|
||||
# Returns the getitem node that extracts the idx-th element from node
|
||||
# (if it exists)
|
||||
def find_getitem_maybe(node: fx.Node, idx: int) -> fx.Node | None:
|
||||
for user in node.users:
|
||||
if is_func(user, operator.getitem) and user.args[1] == idx:
|
||||
return user
|
||||
return None
|
||||
|
||||
|
||||
def find_getitem(node: fx.Node, idx: int) -> fx.Node:
    """Return the getitem user extracting index *idx* from *node*; must exist."""
    match = find_getitem_maybe(node, idx)
    assert match is not None, f"Could not find getitem {idx} in node {node}"
    return match
|
||||
|
||||
|
||||
# An auto-functionalization-aware utility for finding nodes with a specific op
|
||||
# Also handles op overload packets and finds all overloads
|
||||
def find_op_nodes(
|
||||
op: OpOverload | OpOverloadPacket, graph: fx.Graph
|
||||
) -> Iterator[fx.Node]:
|
||||
if isinstance(op, OpOverloadPacket):
|
||||
for overload in op.overloads():
|
||||
overload_op = getattr(op, overload)
|
||||
yield from find_op_nodes(overload_op, graph)
|
||||
return
|
||||
|
||||
assert isinstance(op, OpOverload)
|
||||
|
||||
yield from graph.find_nodes(op="call_function", target=op)
|
||||
|
||||
for n in graph.find_nodes(op="call_function", target=auto_functionalized):
|
||||
if n.args[0] == op:
|
||||
yield n
|
||||
|
||||
|
||||
def get_only_user(node: fx.Node) -> fx.Node:
    """Return the sole user of *node*, asserting exactly one exists.

    Note: even a node with only one user might share storage with another
    node, which might need to be taken into account by callers.
    """
    users = list(node.users)
    assert len(users) == 1
    return users[0]
|
||||
134
vllm/compilation/passes/inductor_pass.py
Normal file
134
vllm/compilation/passes/inductor_pass.py
Normal file
@@ -0,0 +1,134 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import hashlib
|
||||
import inspect
|
||||
import json
|
||||
import types
|
||||
from collections.abc import Callable, Generator
|
||||
from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar
|
||||
|
||||
import torch
|
||||
from torch import fx
|
||||
from torch._subclasses.fake_tensor import FakeTensorMode, unset_fake_temporarily
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config.utils import Range
|
||||
|
||||
from torch._inductor.custom_graph_pass import CustomGraphPass
|
||||
|
||||
# Currently-active PassContext; managed exclusively by pass_context() below.
_pass_context = None
# Generic helpers for decorator typing in this module.
P = ParamSpec("P")
R = TypeVar("R")
|
||||
|
||||
|
||||
class PassContext:
    """State made available to passes while they run (see get_pass_context)."""

    def __init__(self, compile_range: Range):
        # The compile range the current compilation targets.
        self.compile_range: Range = compile_range
|
||||
|
||||
|
||||
def get_pass_context() -> PassContext:
    """Get the current pass context."""
    ctx = _pass_context
    # Only valid inside a pass_context() block.
    assert ctx is not None
    return ctx
|
||||
|
||||
|
||||
@contextmanager
def pass_context(compile_range: Range) -> Generator[None, None, None]:
    """A context manager that stores the current pass context,
    usually it is a list of sizes to specialize.
    """
    global _pass_context
    saved = _pass_context
    _pass_context = PassContext(compile_range)
    try:
        yield
    finally:
        # Restore the previous context even on error, so nesting is safe.
        _pass_context = saved
|
||||
|
||||
|
||||
class InductorPass(CustomGraphPass): # type: ignore[misc]
|
||||
"""
|
||||
A custom graph pass that uses a hash of its source as the UUID.
|
||||
This is defined as a convenience and should work in most cases.
|
||||
"""
|
||||
|
||||
def uuid(self) -> str:
|
||||
"""
|
||||
Provide a unique identifier for the pass, used in Inductor code cache.
|
||||
This should depend on the pass implementation, so that changes to the
|
||||
pass result in recompilation.
|
||||
By default, the object source is hashed.
|
||||
"""
|
||||
return InductorPass.hash_source(self)
|
||||
|
||||
@staticmethod
|
||||
def hash_source(*srcs: str | Any) -> str:
|
||||
"""
|
||||
Utility method to hash the sources of functions or objects.
|
||||
:param srcs: strings or objects to add to the hash.
|
||||
Objects and functions have their source inspected.
|
||||
:return:
|
||||
"""
|
||||
hasher = hashlib.sha256()
|
||||
for src in srcs:
|
||||
if isinstance(src, str):
|
||||
src_str = src
|
||||
elif isinstance(src, (types.FunctionType, type)):
|
||||
src_str = inspect.getsource(src)
|
||||
else:
|
||||
# object instance
|
||||
src_str = inspect.getsource(src.__class__)
|
||||
hasher.update(src_str.encode("utf-8"))
|
||||
return hasher.hexdigest()
|
||||
|
||||
@staticmethod
|
||||
def hash_dict(dict_: dict[Any, Any]) -> str:
|
||||
"""
|
||||
Utility method to hash a dictionary, can alternatively be used for uuid.
|
||||
:return: A sha256 hash of the json rep of the dictionary.
|
||||
"""
|
||||
encoded = json.dumps(dict_, sort_keys=True).encode("utf-8")
|
||||
return hashlib.sha256(encoded).hexdigest()
|
||||
|
||||
def is_applicable_for_range(self, compile_range: Range) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
class CallableInductorPass(InductorPass):
    """
    This class is a wrapper for a callable that automatically provides an
    implementation of the UUID.
    """

    def __init__(
        self, callable: Callable[[fx.Graph], None], uuid: Any | None = None
    ) -> None:
        self.callable = callable
        # When no explicit uuid is given, fall back to hashing the
        # callable's source.
        if uuid is None:
            uuid = self.hash_source(callable)
        self._uuid = uuid

    def __call__(self, graph: torch.fx.Graph) -> None:
        # Delegate directly to the wrapped callable.
        self.callable(graph)

    def uuid(self) -> Any:
        return self._uuid
|
||||
|
||||
|
||||
def enable_fake_mode(fn: Callable[P, R]) -> Callable[P, R]:
|
||||
"""
|
||||
Applies a FakeTensorMode context. This is useful when you don't want to
|
||||
create or run things with real tensors.
|
||||
"""
|
||||
|
||||
@functools.wraps(fn)
|
||||
def fn_new(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
with torch._guards.tracing(None), unset_fake_temporarily(), FakeTensorMode():
|
||||
result = fn(*args, **kwargs)
|
||||
|
||||
return result
|
||||
|
||||
return fn_new
|
||||
178
vllm/compilation/passes/pass_manager.py
Normal file
178
vllm/compilation/passes/pass_manager.py
Normal file
@@ -0,0 +1,178 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import functools
|
||||
from collections.abc import Callable
|
||||
from typing import Any, ParamSpec, TypeVar
|
||||
|
||||
from torch import fx as fx
|
||||
|
||||
from vllm import envs
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.compilation.passes.utility.post_cleanup import PostCleanupPass
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.system_utils import set_env_var
|
||||
|
||||
from .vllm_inductor_pass import VllmInductorPass
|
||||
|
||||
if rocm_aiter_ops.is_enabled():
|
||||
from .fusion.rocm_aiter_fusion import (
|
||||
RocmAiterRMSNormQuantFusionPass,
|
||||
RocmAiterSiluMulFp8GroupQuantFusionPass,
|
||||
RocmAiterTritonAddRMSNormPadFusionPass,
|
||||
)
|
||||
|
||||
if current_platform.is_cuda_alike():
|
||||
from .fusion.act_quant_fusion import ActivationQuantFusionPass
|
||||
from .fusion.attn_quant_fusion import AttnFusionPass
|
||||
from .fusion.qk_norm_rope_fusion import QKNormRoPEFusionPass
|
||||
from .fusion.rms_quant_fusion import RMSNormQuantFusionPass
|
||||
from .fusion.rope_kvcache_fusion import RopeKVCacheFusionPass
|
||||
from .fusion.sequence_parallelism import SequenceParallelismPass
|
||||
from .utility.scatter_split_replace import ScatterSplitReplacementPass
|
||||
from .utility.split_coalescing import SplitCoalescingPass
|
||||
|
||||
if current_platform.is_cuda():
|
||||
from .fusion.allreduce_rms_fusion import AllReduceFusionPass
|
||||
from .fusion.collective_fusion import AsyncTPPass
|
||||
|
||||
from .inductor_pass import (
|
||||
CustomGraphPass,
|
||||
InductorPass,
|
||||
get_pass_context,
|
||||
)
|
||||
from .utility.fix_functionalization import FixFunctionalizationPass
|
||||
from .utility.noop_elimination import NoOpEliminationPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
def with_pattern_match_debug(fn: Callable[P, R]) -> Callable[P, R]:
    """Enable Inductor pattern-match debugging for the duration of a call.

    The env var is set only while *fn* runs, which avoids logging the
    builtin Inductor pattern matching that happens elsewhere.
    """

    @functools.wraps(fn)
    def inner(*args: P.args, **kwargs: P.kwargs) -> R:
        debug_val = envs.VLLM_PATTERN_MATCH_DEBUG
        if debug_val is None:
            # Debugging disabled: call through untouched.
            return fn(*args, **kwargs)
        # optionally check rank here
        with set_env_var("TORCHINDUCTOR_PATTERN_MATCH_DEBUG", debug_val):
            return fn(*args, **kwargs)

    return inner
|
||||
|
||||
|
||||
class PostGradPassManager(CustomGraphPass):  # type: ignore[misc]
    """
    The pass manager for post-grad passes.
    It handles configuration, adding custom passes, and running passes.
    It supports uuid for the Inductor code cache. That includes torch<2.6
    support using pickling (in .inductor_pass.CustomGraphPass).

    The order of the post-grad post-passes is:
    1. passes (constructor parameter)
    2. default passes (NoopEliminationPass, FusionPass)
    3. config["post_grad_custom_post_pass"] (if it exists)
    4. fix_functionalization
    This way, all passes operate on a functionalized graph.
    """

    def __init__(self) -> None:
        # Filled in by configure()/add(); empty here so the manager can be
        # constructed before a VllmConfig is available.
        self.passes: list[InductorPass] = []

    @with_pattern_match_debug
    def __call__(self, graph: fx.Graph) -> None:
        """Run every applicable pass on ``graph``, then the cleanup passes."""
        VllmInductorPass.dump_prefix = 0  # reset dump index

        # Passes may opt out of specific dynamic-shape compile ranges.
        compile_range = get_pass_context().compile_range
        for pass_ in self.passes:
            if pass_.is_applicable_for_range(compile_range):
                pass_(graph)
                VllmInductorPass.dump_prefix += 1
            else:
                logger.debug("Skipping %s with compile range %s", pass_, compile_range)

        # post-cleanup goes before fix_functionalization
        # because it requires a functional graph
        self.post_cleanup(graph)
        VllmInductorPass.dump_prefix += 1

        # always run fix_functionalization last
        self.fix_functionalization(graph)
        VllmInductorPass.dump_prefix = None  # Cleanup index

    def configure(self, config: VllmConfig) -> None:
        """Instantiate the enabled passes from ``config``.

        Must be called before the manager is invoked on a graph. Also
        creates the post_cleanup and fix_functionalization passes that
        always run at the end of __call__.
        """
        self.pass_config = config.compilation_config.pass_config

        # Set the current vllm config to allow tracing CustomOp instances
        with set_current_vllm_config(config, check_compile=False):
            if self.pass_config.eliminate_noops:
                self.passes += [NoOpEliminationPass(config)]

            if self.pass_config.enable_sp:
                self.passes += [SequenceParallelismPass(config)]
            if self.pass_config.fuse_gemm_comms:
                self.passes += [AsyncTPPass(config)]

            if self.pass_config.fuse_allreduce_rms:
                self.passes += [AllReduceFusionPass(config)]

            if self.pass_config.fuse_norm_quant:
                self.passes += [RMSNormQuantFusionPass(config)]
                if rocm_aiter_ops.is_enabled():
                    self.passes += [
                        RocmAiterRMSNormQuantFusionPass(config),
                    ]
            if self.pass_config.fuse_act_quant:
                self.passes += [ActivationQuantFusionPass(config)]
                if rocm_aiter_ops.is_enabled():
                    self.passes += [RocmAiterSiluMulFp8GroupQuantFusionPass(config)]

            if self.pass_config.fuse_act_padding and rocm_aiter_ops.is_enabled():
                self.passes += [RocmAiterTritonAddRMSNormPadFusionPass(config)]

            if self.pass_config.fuse_rope_kvcache:
                # The two utility passes normalize the graph so the
                # RoPE+KV-cache fusion pattern can match.
                self.passes += [SplitCoalescingPass(config)]
                self.passes += [ScatterSplitReplacementPass(config)]
                self.passes += [RopeKVCacheFusionPass(config)]

            if self.pass_config.fuse_attn_quant:
                self.passes += [AttnFusionPass(config)]

            if self.pass_config.enable_qk_norm_rope_fusion:
                self.passes += [SplitCoalescingPass(config)]
                self.passes += [QKNormRoPEFusionPass(config)]

        # needs a functional graph
        self.post_cleanup = PostCleanupPass(config)
        self.fix_functionalization = FixFunctionalizationPass(config)

    def add(self, pass_: InductorPass) -> None:
        """Append a custom ``InductorPass`` to run after the default ones."""
        assert isinstance(pass_, InductorPass)
        self.passes.append(pass_)

    def uuid(self) -> str:
        """
        The PostGradPassManager is set as a custom pass in the Inductor and
        affects compilation caching. Its uuid depends on the UUIDs of all
        dependent passes and the pass config. See InductorPass for more info.
        """
        passes = []

        state: dict[str, Any] = {"pass_config": self.pass_config.compute_hash()}
        for pass_ in self.passes:
            passes.append(pass_.uuid())
        passes.append(self.fix_functionalization.uuid())

        # Include the compile range in the uuid to ensure that inductor
        # recompiles the graph for the new dynamic compile range.
        state["compile_range"] = str(get_pass_context().compile_range)
        state["passes"] = passes
        return InductorPass.hash_dict(state)
|
||||
0
vllm/compilation/passes/utility/__init__.py
Normal file
0
vllm/compilation/passes/utility/__init__.py
Normal file
301
vllm/compilation/passes/utility/fix_functionalization.py
Normal file
301
vllm/compilation/passes/utility/fix_functionalization.py
Normal file
@@ -0,0 +1,301 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import operator
|
||||
from collections.abc import Iterable
|
||||
|
||||
import torch
|
||||
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ..fx_utils import is_func
|
||||
from ..vllm_inductor_pass import VllmInductorPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class FixFunctionalizationPass(VllmInductorPass):
    """
    This pass defunctionalizes certain nodes to avoid redundant tensor copies.
    After this pass, DCE (dead-code elimination) should never be run,
    as de-functionalized nodes may appear as dead code.

    To add new nodes to defunctionalize, add to the if-elif chain in __call__.
    """

    @VllmInductorPass.time_and_log
    def __call__(self, graph: torch.fx.Graph) -> None:
        """Replace auto_functionalized wrappers of known ops with direct,
        in-place calls, then erase the staged dead nodes in one sweep."""
        # XPU does not support auto-functionalization yet.
        # Will enable this when switch to vllm-xpu-kernels.
        if current_platform.is_xpu():
            logger.debug(
                "XPU platform does not support fix functionalizationpass currently."
            )
            return

        # Nodes are only staged here and erased after the traversal, so the
        # graph is never mutated while we iterate over graph.nodes.
        self.nodes_to_remove: list[torch.fx.Node] = []
        count = 0
        for node in graph.nodes:
            if not is_func(node, auto_functionalized):
                continue  # Avoid deep if-elif nesting

            kwargs = node.kwargs
            at_target = node.args[0]

            if at_target == torch.ops._C.rotary_embedding.default:
                query = kwargs["query"]
                key = kwargs["key"]
                getitem_nodes = self.getitem_users(node)

                if (
                    is_func(query, operator.getitem)
                    and is_func(key, operator.getitem)
                    and query.args[0] == key.args[0]
                    and is_func(query.args[0], torch.ops.aten.split_with_sizes.default)
                    and all(
                        is_func(user, torch.ops.aten.slice_scatter.default)
                        for getitem_node in getitem_nodes.values()
                        for user in getitem_node.users
                    )
                ):
                    # Pattern where query and key are slices of an mm_node.
                    # While functionalized, results at [1] and [2] are scattered
                    # back into mm_node. So after de-functionalization, we can
                    # just use mm_node directly.

                    mm_node = query.args[0].args[0]
                    for user in getitem_nodes.values():
                        for user_of_getitem in user.users:
                            if is_func(
                                user_of_getitem, torch.ops.aten.slice_scatter.default
                            ):
                                user_of_getitem.replace_all_uses_with(mm_node)
                                self._remove(user_of_getitem)
                        self._remove(user)

                    self.insert_defunctionalized(graph, node)
                    self._remove(node)

                else:
                    # Directly replace the auto_functionalize(rotary_embedding)
                    # with the inplace rotary_embedding. In theory, we shouldn't
                    # do this blindly, but in practice in vLLM it's ok. The best
                    # solution is to use auto_functionalization_v2 and then use
                    # inductor's builtin defunctionalization (reinplacing) pass.
                    mutated_args = {1: "query", 2: "key"}
                    self.defunctionalize(graph, node, mutated_args)

            # rms_norm replacements avoid the most copies for LLaMa.
            elif at_target == torch.ops._C.fused_add_rms_norm.default:
                mutated_args = {1: "input", 2: "residual"}
                self.defunctionalize(graph, node, mutated_args)
            elif at_target == torch.ops._C.fused_add_rms_norm_static_fp8_quant.default:  # noqa: E501
                mutated_args = {1: "result", 2: "residual"}
                self.defunctionalize(graph, node, mutated_args)
            elif at_target == torch.ops._C.rms_norm_dynamic_per_token_quant.default:  # noqa: E501
                mutated_args = {1: "result", 2: "scale", 3: "residual"}
                self.defunctionalize(graph, node, mutated_args)
            elif at_target in [
                torch.ops._C.rms_norm.default,
                torch.ops._C.rms_norm_static_fp8_quant.default,
            ]:
                mutated_args = {1: "result"}
                self.defunctionalize(graph, node, mutated_args)
            elif (
                hasattr(torch.ops.vllm, "flashinfer_trtllm_fused_allreduce_norm")
                and at_target
                == torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default
            ):
                mutated_args = {
                    1: "allreduce_in",
                    2: "residual",
                    3: "norm_out",
                    4: "quant_out",
                    5: "scale_out",
                }
                self.defunctionalize(graph, node, mutated_args)
            # For some reason we need to specify the args for both
            # silu_and_mul and silu_and_mul_quant. The kwargs
            # pathway gets the wrong answer.
            elif at_target == torch.ops._C.silu_and_mul.default:
                mutated_args = {1: "result"}
                self.defunctionalize(
                    graph, node, mutated_args, args=("result", "input")
                )
            elif at_target == torch.ops._C.silu_and_mul_quant.default:
                mutated_args = {1: "result"}
                self.defunctionalize(
                    graph, node, mutated_args, args=("result", "input", "scale")
                )
            elif (
                hasattr(torch.ops._C, "silu_and_mul_nvfp4_quant")
                and at_target == torch.ops._C.silu_and_mul_nvfp4_quant.default
            ):
                mutated_args = {1: "result", 2: "result_block_scale"}
                self.defunctionalize(
                    graph,
                    node,
                    mutated_args,
                    args=(
                        "result",
                        "result_block_scale",
                        "input",
                        "input_global_scale",
                    ),
                )
            # Defunctionalize fused_qk_norm_rope to remove higher-order wrapper.
            elif at_target == torch.ops._C.fused_qk_norm_rope.default:
                mutated_args = {1: "qkv"}
                args = (
                    "qkv",
                    "num_heads_q",
                    "num_heads_k",
                    "num_heads_v",
                    "head_dim",
                    "eps",
                    "q_weight",
                    "k_weight",
                    "cos_sin_cache",
                    "is_neox",
                    "position_ids",
                )
                self.defunctionalize(graph, node, mutated_args=mutated_args, args=args)
            elif (
                hasattr(torch.ops.vllm, "fused_rope_and_unified_kv_cache_update")
                and at_target
                == torch.ops.vllm.fused_rope_and_unified_kv_cache_update.default
            ):
                mutated_args = {
                    1: "query",
                    2: "key",
                }
                self.defunctionalize(graph, node, mutated_args=mutated_args)
            # only used for test_functionalization::TestFunctionWithMutatedArgsAndReturn
            elif (
                hasattr(torch.ops.vllm, "function_with_mutated_args_and_return")
                and at_target
                == torch.ops.vllm.function_with_mutated_args_and_return.default
            ):
                mutated_args = {1: "x"}
                self.defunctionalize(graph, node, mutated_args=mutated_args)
            else:
                continue  # skip the count

            count += 1

        self.dump_graph(graph, "before_cleanup")

        # Remove the nodes all at once
        count_removed = len(self.nodes_to_remove)
        for node in self.nodes_to_remove:
            graph.erase_node(node)

        logger.debug(
            "De-functionalized %s nodes, removed %s nodes", count, count_removed
        )
        self.nodes_to_remove.clear()

    def _remove(self, node_or_nodes: torch.fx.Node | Iterable[torch.fx.Node]) -> None:
        """
        Stage a node (or nodes) for removal at the end of the pass.
        """
        if isinstance(node_or_nodes, torch.fx.Node):
            self.nodes_to_remove.append(node_or_nodes)
        else:
            self.nodes_to_remove.extend(node_or_nodes)

    def defunctionalize(
        self,
        graph: torch.fx.Graph,
        node: torch.fx.Node,
        mutated_args: dict[int, torch.fx.Node | str],
        args: tuple[torch.fx.Node | str, ...] | None = None,
    ) -> None:
        """
        De-functionalize a node by replacing it with a call to the original.
        It also replaces the getitem users with the mutated arguments.
        See replace_users_with_mutated_args and insert_defunctionalized.
        """
        self.replace_users_with_mutated_args(node, mutated_args)
        self.insert_defunctionalized(graph, node, args=args)
        self._remove(node)

    def replace_users_with_mutated_args(
        self, node: torch.fx.Node, mutated_args: dict[int, torch.fx.Node | str]
    ) -> None:
        """
        Replace mutated getitem users of the auto-functionalized node with the
        mutated arguments.
        :param node: The auto-functionalized node
        :param mutated_args: The mutated arguments, indexed by getitem index.
        If the value of an arg is a string, `node.kwargs[arg]` is used.
        """
        for idx, user in self.getitem_users(node).items():
            # Some functionalized nodes may return both a result at getitem[0]
            # as well as mutated args at getitem[1:...]
            if idx == 0:
                assert idx not in mutated_args, (
                    f"result at getitem[0] should not be in mutated_args for {node}"
                )
                continue
            arg = mutated_args[idx]
            arg = node.kwargs[arg] if isinstance(arg, str) else arg
            user.replace_all_uses_with(arg)
            self._remove(user)

    def getitem_users(self, node: torch.fx.Node) -> dict[int, torch.fx.Node]:
        """
        Returns the operator.getitem users of the auto-functionalized node,
        indexed by the index they are getting.
        """
        users = {}
        for user in node.users:
            if is_func(user, operator.getitem):
                idx = user.args[1]
                users[idx] = user
        return users

    def insert_defunctionalized(
        self,
        graph: torch.fx.Graph,
        node: torch.fx.Node,
        args: tuple[torch.fx.Node | str, ...] | None = None,
    ) -> None:
        """
        Insert a new defunctionalized node into the graph before node.
        If one of the kwargs is 'out', provide args directly,
        as node.kwargs cannot be used.
        See https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351

        :param graph: Graph to insert the defunctionalized node into
        :param node: The auto-functionalized node to defunctionalize
        :param args: If we cannot use kwargs, specify args directly.
        If an arg is a string, `node.kwargs[arg]` is used.
        """  # noqa: E501
        assert is_func(node, auto_functionalized), (
            f"node must be auto-functionalized, is {node} instead"
        )

        # Create a new call to the original function
        with graph.inserting_before(node):
            function = node.args[0]
            if args is None:
                fn_node = graph.call_function(function, kwargs=node.kwargs)
            else:
                # Args passed as strings refer to items in node.kwargs
                args = tuple(
                    node.kwargs[arg] if isinstance(arg, str) else arg for arg in args
                )
                fn_node = graph.call_function(function, args=args)

        # If the function returns a value as well as mutating args inplace,
        # the functionalized node will have a getitem[0] user that holds this value
        # Replace getitem[0] user of the auto-functionalized node
        # with the new defunctionalized node directly if it exists
        users = self.getitem_users(node)
        if 0 in users:
            user = users[0]
            user.replace_all_uses_with(fn_node)
            self._remove(user)
|
||||
130
vllm/compilation/passes/utility/noop_elimination.py
Normal file
130
vllm/compilation/passes/utility/noop_elimination.py
Normal file
@@ -0,0 +1,130 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Iterable
|
||||
|
||||
import torch.fx
|
||||
from torch import SymInt
|
||||
from torch.fx.experimental.symbolic_shapes import statically_known_true
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from ..fx_utils import is_func
|
||||
from ..vllm_inductor_pass import VllmInductorPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class NoOpEliminationPass(VllmInductorPass):
    """
    This is an inductor pass that removes redundant reshape/slice operations.
    It is required for RMSNorm-quant fusion to work properly.
    That's because apply_fp8_linear adds a reshape, which is redundant
    in the 2D-case. Additionally, torch internal no-op elimination pass does
    not handle certain slice variants.

    Cases handled:
    1. A chain of reshapes is equivalent to the last reshape called on the
    base tensor (input of the first reshape).
    2. A reshape that produces the shape of the input is redundant
    3. A slice that produces the shape of the input is redundant

    Example graph 1:
    mul_1: "f16[s0, 4096]" = ...
    view_1: "f16[s0, 128, 32]" = torch.reshape(mul_1, [-1, 128, 32])
    view_2: "f16[s0, 4096]" = torch.reshape(view_1, [-1, 4096])
    view_3: "f16[s0, 128, 32]" = torch.reshape(view_2, [-1, 128, 32])

    Can be replaced with:
    mul_1: "f16[s0, 4096]" = ...
    view_3: "f16[s0, 128, 32]" = ...

    Example graph 2:
    getitem_1: "f16[s0, 4096]" = ...
    view_1: "f16[s0, 4096]" = torch.reshape(getitem_1, [-1, 4096])
    at = auto_functionalized(static_scaled_fp8_quant, input = view_1, ...)
    out: "f8e4m3fn[s0, 4096]" = at[1]

    Can be replaced with:
    getitem_1: "f16[s0, 4096]" = ...
    at = auto_functionalized(static_scaled_fp8_quant, input = getitem_1, ...)
    out: "f8e4m3fn[s0, 4096]" = at[1]

    Example graph 3:
    arg0: "s0" = SymInt(s0)
    scaled_mm: "f16[s0, 4096]" = ...
    slice_1: "f16[s0, 4096]" = torch.slice(scaled_mm, -1, 0, arg0)
    at = auto_functionalized(fused_add_rms_norm, input = slice_1, ...)
    out: "f16[s0, 4096]" = torch.slice_scatter(scaled_mm, at[1], 0, 0, arg0)

    Can be replaced with:
    arg0: "s0" = SymInt(s0)
    scaled_mm: "f16[s0, 4096]" = ...
    at = auto_functionalized(fused_add_rms_norm, input = scaled_mm, ...)
    out: "f16[s0, 4096]" = at[1]
    """

    @VllmInductorPass.time_and_log
    def __call__(self, graph: torch.fx.Graph) -> None:
        """Remove no-op reshape/slice/slice_scatter nodes from ``graph``."""
        count = 0
        # Remove no-op reshapes/views:
        for node in graph.nodes:
            if is_func(node, torch.ops.aten.reshape.default):
                # Case 1: rewrite reshape chains to reshapes on the base tensor
                input = node.args[0]
                # If the input is a reshape, rebind to that node
                if is_func(input, torch.ops.aten.reshape.default):
                    # The new input is guaranteed not to be a reshape,
                    # because we process nodes in order
                    node.update_arg(0, input.args[0])
                    if len(input.users) == 0:
                        graph.erase_node(input)
                        count += 1

            # remove reshape/slice if it produces the original shape
            if is_func(node, torch.ops.aten.reshape.default) or is_func(
                node, torch.ops.aten.slice.Tensor
            ):
                input = node.args[0]
                input_shape = input.meta["val"].shape
                output_shape = node.meta["val"].shape
                if self.all_dims_equivalent(input_shape, output_shape):
                    node.replace_all_uses_with(input)
                    graph.erase_node(node)
                    count += 1
            elif is_func(node, torch.ops.aten.slice_scatter.default):
                base, view, dim_index, start, end = node.args[:5]
                base_shape = base.meta["val"].shape
                view_shape = view.meta["val"].shape

                # A scatter covering the whole base tensor is just the view.
                if self.all_dims_equivalent(base_shape, view_shape):
                    node.replace_all_uses_with(view)
                    graph.erase_node(node)
                    count += 1

        logger.debug("Removed %s no-op reshapes and slices", count)

    # ---------------------- Shape comparison helpers ----------------------
    def dims_equivalent(self, dim: int | SymInt, i_dim: int | SymInt) -> bool:
        """
        This function checks if two dimensions are equivalent.
        :param dim: The dimension arg to reshape/slice
        :param i_dim: The corresponding dimension in the input tensor
        :return: Are the dimensions equivalent?

        There are two cases in which the dimensions are equivalent:
        1. The dimensions are equal (both integers)
        2. The dimensions both correspond to the same SymInt
        """
        # statically_known_true covers both the int and SymInt cases.
        return statically_known_true(dim == i_dim)  # type: ignore[no-any-return]

    def all_dims_equivalent(
        self, dims: Iterable[int | SymInt], i_dims: Iterable[int | SymInt]
    ) -> bool:
        """Return True iff both shapes have the same rank and every
        corresponding pair of dimensions is statically equivalent."""
        dims_ = list(dims)
        i_dims_ = list(i_dims)
        if len(dims_) != len(i_dims_):
            # Different ranks can't be equivalent
            return False
        # Fix: zip the materialized lists, not the original iterables.
        # If a caller passes one-shot iterators, list() above consumes them,
        # so zip(dims, i_dims) would be empty and all() vacuously True —
        # falsely reporting non-equivalent shapes as equivalent.
        return all(self.dims_equivalent(s, i_s) for s, i_s in zip(dims_, i_dims_))
|
||||
21
vllm/compilation/passes/utility/post_cleanup.py
Normal file
21
vllm/compilation/passes/utility/post_cleanup.py
Normal file
@@ -0,0 +1,21 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from torch import fx
|
||||
|
||||
from ..vllm_inductor_pass import VllmInductorPass
|
||||
|
||||
|
||||
class PostCleanupPass(VllmInductorPass):
    """
    This pass performs cleanup after custom passes.
    It topologically sorts the graph and removes unused nodes.
    This is needed because the pattern matcher does not guarantee producing
    a topologically sorted graph, and there may be unused nodes left around.
    """

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph) -> None:
        # NOTE(review): imported lazily here rather than at module top —
        # presumably to defer loading Inductor internals; confirm before
        # hoisting.
        from torch._inductor.pattern_matcher import stable_topological_sort

        stable_topological_sort(graph)
        # Drop nodes left without users by earlier pattern replacements.
        graph.eliminate_dead_code()
|
||||
138
vllm/compilation/passes/utility/scatter_split_replace.py
Normal file
138
vllm/compilation/passes/utility/scatter_split_replace.py
Normal file
@@ -0,0 +1,138 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Replace ``slice_scatter`` and ``split_with_sizes`` nodes with a single
|
||||
assignment if there are no users for the inplace tensor written to by
|
||||
the slice_scatter call.
|
||||
|
||||
The inplace rotary_embedding custom op takes in mutable query and key inputs
|
||||
that are split+getitem outputs of a single qkv tensor.
|
||||
When functionalized, we fetch the rotated query and key from the functionalized op
|
||||
using `getitem` calls. However, we also write to the qkv tensor inplace using a
|
||||
`slice_scatter`, then split the inplace tensor to get the output tensors again.
|
||||
Instead, if the inplace tensor has no subsequent users, we can just replace the
|
||||
`slice_scatter` and `split_with_sizes` nodes with the `getitem` calls.
|
||||
|
||||
This is already done in fix_functionalization::FixFunctionalizationPass, but
|
||||
writing a custom pass for it before defunctionalization allows matching against the
|
||||
qkv split+rotary_embedding subpattern as part of e.g. the RoPE+KVCache fusion pass.
|
||||
"""
|
||||
|
||||
import operator
|
||||
|
||||
import torch
|
||||
from torch import fx
|
||||
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from ..fx_utils import is_func
|
||||
from ..vllm_inductor_pass import VllmInductorPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class ScatterSplitReplacementPass(VllmInductorPass):
    """Replace getitem+slice_scatter+split nodes with a single getitem when
    the inplace subtensor written to by the slice_scatter has no other users.

    Here's an example graph with q_size = 512, kv_size = 64:
    split_with_sizes_1 = torch.ops.aten.split_with_sizes.default(qkv, (512, 64, 64), -1)
    at = auto_functionalized(torch.ops._C.rotary_embedding.default(positions, q, k))
    q = operator.getitem(at, 1)
    k = operator.getitem(at, 2)
    torch.ops.aten.slice_scatter.default(qkv, q, [0, 512], -1)
    torch.ops.aten.slice_scatter.default(qkv, k, [512, 512 + 64], -1)
    split_with_sizes_2 = torch.ops.aten.split_with_sizes.default(qkv, (512, 64, 64), -1)
    q = operator.getitem(split_with_sizes_2, 0)
    k = operator.getitem(split_with_sizes_2, 1)
    v = operator.getitem(split_with_sizes_2, 2)

    After this pass, this sequence of nodes is replaced with:
    split_with_sizes_1 = torch.ops.aten.split_with_sizes.default(qkv, (512, 64, 64), -1)
    at = auto_functionalized(torch.ops._C.rotary_embedding.default(positions, q, k))
    q = operator.getitem(at, 1)
    k = operator.getitem(at, 2)
    v = operator.getitem(split_with_sizes_1, 2)
    """

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph) -> None:
        count = 0

        # rotary_embedding variants whose functionalized form we rewrite.
        target_ops = [torch.ops._C.rotary_embedding.default]
        if hasattr(torch.ops.vllm, "rocm_aiter_triton_rotary_embedding"):
            target_ops.append(torch.ops.vllm.rocm_aiter_triton_rotary_embedding.default)

        for node in graph.nodes:
            if not is_func(node, auto_functionalized):
                continue

            kwargs = node.kwargs
            at_target = node.args[0]

            if at_target in target_ops:
                query = kwargs["query"]
                key = kwargs["key"]
                # Map getitem index -> getitem node for this functionalized op.
                getitem_nodes = {}
                for user in node.users:
                    if is_func(user, operator.getitem):
                        getitem_nodes[user.args[1]] = user

                if (
                    is_func(query, operator.getitem)
                    and is_func(key, operator.getitem)
                    and query.args[0] == key.args[0]
                    and is_func(query.args[0], torch.ops.aten.split_with_sizes.default)
                    and all(
                        is_func(user, torch.ops.aten.slice_scatter.default)
                        for getitem_node in getitem_nodes.values()
                        for user in getitem_node.users
                    )
                ):
                    # Pattern where query and key are slices of a qkv tensor.
                    # While functionalized, results at [1] and [2] are scattered
                    # back into qkv, then split again to get query and key.
                    # If the inplace tensor has no other users, we can replace
                    # the slice_scatter+split nodes with the original results.

                    # NOTE(review): these search loops keep the LAST user that
                    # iteration binds; the all(...) guard above ensures every
                    # getitem user is a slice_scatter, but the loops assume
                    # indices 1 and 2 exist in getitem_nodes — would KeyError
                    # otherwise; confirm the guard implies this.
                    for user in getitem_nodes[1].users:
                        slice_scatter_1_node = user
                        if not is_func(
                            slice_scatter_1_node, torch.ops.aten.slice_scatter.default
                        ):
                            continue

                    for user in getitem_nodes[2].users:
                        slice_scatter_2_node = user
                        if not is_func(
                            slice_scatter_2_node, torch.ops.aten.slice_scatter.default
                        ):
                            continue

                    # Find the split that re-reads the scattered qkv tensor.
                    for user in slice_scatter_2_node.users:
                        split_node = user
                        if not is_func(split_node, torch.ops.aten.split_with_sizes.default):
                            continue

                    # Map getitem index -> getitem node for the second split.
                    split_getitem_users = {}
                    for user in split_node.users:
                        if is_func(user, operator.getitem):
                            split_getitem_users[user.args[1]] = user

                    # Replace query node
                    split_getitem_users[0].replace_all_uses_with(getitem_nodes[1])
                    graph.erase_node(split_getitem_users[0])
                    # Replace key node
                    split_getitem_users[1].replace_all_uses_with(getitem_nodes[2])
                    graph.erase_node(split_getitem_users[1])
                    # Redirect value node to original qkv tensor
                    split_getitem_users[2].replace_input_with(split_node, query.args[0])

                    # Erase unused nodes
                    graph.erase_node(split_node)
                    graph.erase_node(slice_scatter_2_node)
                    graph.erase_node(slice_scatter_1_node)

                    count += 1

        logger.debug("Eliminated %d slice_scatter+split nodes", count)
|
||||
70
vllm/compilation/passes/utility/split_coalescing.py
Normal file
70
vllm/compilation/passes/utility/split_coalescing.py
Normal file
@@ -0,0 +1,70 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Coalesce duplicate ``split_with_sizes`` nodes that operate on the same
|
||||
input tensor with the same split sizes.
|
||||
|
||||
On certain hardware/dtype combinations (e.g. B200 + FP8) the Inductor
|
||||
graph may contain multiple ``split_with_sizes`` calls on the same tensor
|
||||
that CSE fails to merge. This pass detects and replaces the duplicates
|
||||
so that downstream pattern-matching passes (e.g. QK-Norm+RoPE fusion)
|
||||
see a single split node with all users attached.
|
||||
|
||||
See also:
|
||||
- vLLM #33295 (original issue)
|
||||
- PyTorch #174472 (upstream CSE gap)
|
||||
"""
|
||||
|
||||
import operator
|
||||
|
||||
import torch
|
||||
from torch import fx
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from ..fx_utils import is_func
|
||||
from ..vllm_inductor_pass import VllmInductorPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class SplitCoalescingPass(VllmInductorPass):
    """Merge duplicate ``split_with_sizes`` nodes that split the same input
    tensor with identical sizes into one canonical node."""

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph) -> None:
        merged = 0

        # Input tensor node -> the distinct split nodes seen for it so far.
        seen: dict[fx.Node, list[fx.Node]] = {}

        for candidate in graph.nodes:
            if not is_func(candidate, torch.ops.aten.split_with_sizes.default):
                continue
            # Only coalesce splits whose results are consumed via getitem.
            if not all(is_func(u, operator.getitem) for u in candidate.users):
                continue

            source, sizes = candidate.args[:2]
            prior_splits = seen.setdefault(source, [])

            # Look for an earlier split of the same tensor with the same sizes.
            canonical = None
            for prior in prior_splits:
                if list(prior.args[1]) == list(sizes):
                    canonical = prior
                    break

            if canonical is None:
                prior_splits.append(candidate)
            else:
                # Redirect all getitem users to the canonical split and drop
                # the duplicate node.
                candidate.replace_all_uses_with(canonical)
                graph.erase_node(candidate)
                merged += 1

        logger.debug("Coalesced %d duplicate split_with_sizes nodes", merged)
|
||||
180
vllm/compilation/passes/vllm_inductor_pass.py
Normal file
180
vllm/compilation/passes/vllm_inductor_pass.py
Normal file
@@ -0,0 +1,180 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import functools
|
||||
import operator
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar
|
||||
|
||||
import regex as re
|
||||
import torch
|
||||
from torch._dynamo.utils import lazy_format_graph_code
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass, PatternPrettyPrinter
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .inductor_pass import InductorPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
class InductorCompilationConfig:
    """Minimal, detachable subset of CompilationConfig needed by inductor
    passes (the full `CompilationConfig` holds a pointer to the model,
    which is unsafe to capture here)."""

    # Ops at whose boundaries the graph is split for piecewise compilation.
    splitting_ops: list[str] | None = None
    # Whether Inductor's native graph partitioning is in use.
    use_inductor_graph_partition: bool = False
|
||||
|
||||
|
||||
class VllmInductorPass(InductorPass):
    """
    An inductor pass with access to vLLM PassConfig.
    It provides timing, logging, and dumping utilities.
    """

    dump_prefix: ClassVar[int | None] = None
    """Keep track of pass index for debug dump ordering."""

    def __init__(self, config: VllmConfig):
        """Extract only the pass-relevant pieces of the global vLLM config.

        Args:
            config: Global vLLM configuration; compilation, model dtype and
                device information are retained on the pass instance.
        """
        # Get only the necessary CompilationConfig for the inductor pass, since
        # full `CompilationConfig` contains pointer to model which is unsafe.
        self.compilation_config = InductorCompilationConfig(
            splitting_ops=config.compilation_config.splitting_ops,
            use_inductor_graph_partition=config.compilation_config.use_inductor_graph_partition,
        )
        self.pass_config = config.compilation_config.pass_config
        self.model_dtype = config.model_config.dtype if config.model_config else None
        self.device: str | None = (
            config.device_config.device if config.device_config else None
        )
        # Used in dump names and timing logs.
        self.pass_name = self.__class__.__name__

    @staticmethod
    def time_and_log(
        call_fn: Callable[["VllmInductorPass", torch.fx.Graph], None],
    ) -> Callable[["VllmInductorPass", torch.fx.Graph], None]:
        """Decorator for pass ``__call__`` methods: dumps the graph before
        and after the pass runs, and logs the wall-clock duration."""

        @functools.wraps(call_fn)
        def wrapped(self: VllmInductorPass, graph: torch.fx.Graph) -> None:
            self.begin()
            self.dump_graph(graph, "before")
            call_fn(self, graph)
            self.dump_graph(graph, "after")
            self.end_and_log()

        return wrapped

    def dump_graph(self, graph: torch.fx.Graph, stage: str) -> None:
        """Lazily dump the graph code under a name that encodes the dump
        prefix (if any), this pass's name, and the stage label."""
        i = VllmInductorPass.dump_prefix
        i_str = "" if i is None else f".{i}"
        lazy_format_graph_code(
            f"post_grad{i_str}.{self.pass_name}.{stage}", graph.owning_module
        )

    def begin(self) -> None:
        """Record the pass start time (nanoseconds)."""
        self._start_time = time.perf_counter_ns()

    def end_and_log(self) -> None:
        """Record the pass end time and log elapsed milliseconds."""
        self._end_time = time.perf_counter_ns()
        duration_ms = float(self._end_time - self._start_time) / 1.0e6
        logger.debug("%s completed in %.1f ms", self.pass_name, duration_ms)
|
||||
|
||||
|
||||
class VllmPatternMatcherPass(VllmInductorPass):
    """
    A VllmInductorPass that uses the Inductor pattern matcher.
    Its main use is providing the dump_patterns utility that dumps the
    Inductor pattern matcher patterns into a file, which greatly aids debugging.

    TODO(luka) move more utilities to this pass.
    """

    matched_count: int = 0
    """The number of matched patterns in the pass."""

    # Matches reprs like <OpOverload(op='aten.mm', overload='default')>.
    _OP_OVERLOAD_PATTERN: ClassVar[re.Pattern] = re.compile(
        r"<OpOverload\(op='([^']*)', overload='([^']*)'\)>"
    )

    def _replace_op_overloads(self, string: str) -> str:
        """Replace <OpOverload(..., ...)> with nicer formulations"""
        return str(
            self._OP_OVERLOAD_PATTERN.sub(
                lambda m: f"torch.ops.{m.group(1)}.{m.group(2)}",
                string,
            )
        )

    def dump_patterns(self, config: VllmConfig, pm_pass: PatternMatcherPass) -> None:
        """
        If debug dumping is enabled, dump the Inductor pattern-matcher patterns
        into the debug_dump_path folder next to the dumped fx graphs.

        This method does its best to print something that looks like Python code
        for easier debugging and potentially navigation. If any errors appear in
        the output, please add to this method.

        TODO(luka): use pattern object to manually produce pattern graph
        """
        debug_dump_path = config.compile_debug_dump_path()
        if not debug_dump_path:
            return

        debug_dump_path.mkdir(parents=True, exist_ok=True)

        from vllm.utils.system_utils import unique_filepath

        file_path = unique_filepath(
            lambda i: debug_dump_path / f"patterns.{self.pass_name}.{i}.py"
        )

        with file_path.open("w") as f:
            # Header: imports that make the dumped pseudo-code resolvable.
            print(
                f"# This file was produced by VllmPatternMatcherPass."
                f"dump_patterns for {self.pass_name}.\n"
                f"# It does its best to produce valid-Python-looking code but"
                f" please add to dump_patterns if there are any errors.\n\n"
                f"from torch._higher_order_ops.auto_functionalize import "
                f"auto_functionalized as auto_functionalized\n"
                f"from torch._inductor.pattern_matcher import *\n"
                f"vllm = torch.ops.vllm",
                file=f,
            )

            for node, patterns in pm_pass.patterns.items():
                # fix the operator.getitem repr
                if node[1] == operator.getitem:
                    node_repr = f"({repr(node[0])}, operator.getitem)"
                else:
                    node_repr = repr(node)

                node_repr = self._replace_op_overloads(node_repr)

                print(f"\n\n# Patterns for op: {node_repr}", file=f)
                for i, pattern in enumerate(patterns):
                    # reserve auto_functionalized ahead of time
                    pp = PatternPrettyPrinter()
                    pp.namespace.create_name("auto_functionalized", None)

                    # Assemble pattern as a fake function body; the final
                    # replace() indents every line into the def's body.
                    out_node = pp.pretty_print(pattern.pattern)
                    pattern_repr = "\n".join(
                        [f"def pattern_{i}():"]
                        + [
                            f"{pp.memoized_objs_names[key]} = "
                            f"{pp.memoized_objs_pp[key]}"
                            for key in pp.memoized_objs_names
                        ]
                        + [f"return {out_node}"]
                    ).replace("\n", "\n    ")

                    pattern_repr = self._replace_op_overloads(pattern_repr)
                    print(f"{pattern_repr}\n", file=f)
|
||||
|
||||
|
||||
class PrinterInductorPass(VllmInductorPass):
    """Trivial debugging pass: dumps the graph under a fixed stage name
    without modifying it."""

    def __init__(self, name: str, config: VllmConfig) -> None:
        """Remember the stage label to dump under."""
        super().__init__(config)
        self.name = name

    def __call__(self, graph: torch.fx.Graph) -> None:
        """Dump the (unchanged) graph tagged with this pass's stage name."""
        self.dump_graph(graph, self.name)
|
||||
343
vllm/compilation/piecewise_backend.py
Normal file
343
vllm/compilation/piecewise_backend.py
Normal file
@@ -0,0 +1,343 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import dataclasses
|
||||
import io
|
||||
import json
|
||||
import pickle
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from pickle import Pickler
|
||||
from typing import Any
|
||||
|
||||
import torch._functorch.config
|
||||
import torch.fx as fx
|
||||
from torch._inductor.runtime.triton_heuristics import CachingAutotuner
|
||||
from torch._logging._internal import trace_structured
|
||||
|
||||
from vllm.compilation.backends import VllmBackend
|
||||
from vllm.compilation.monitor import end_monitoring_torch_compile
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.utils import Range
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclasses.dataclass
class RangeEntry:
    """Per-compile-range state: whether the range has been compiled yet,
    and the callable to dispatch to once it has."""

    # The shape range this entry covers.
    compile_range: Range
    # True once `runnable` holds the compiled artifact for this range.
    compiled: bool = False
    # Compiled callable for this range; None until compilation happens.
    runnable: Callable[..., Any] | None = None
|
||||
|
||||
|
||||
class PiecewiseBackend:
    def __init__(
        self,
        graph: fx.GraphModule | None,
        vllm_config: VllmConfig,
        piecewise_compile_index: int,
        total_piecewise_compiles: int,
        sym_shape_indices: list[int],
        vllm_backend: VllmBackend,
        returns_tuple: bool,
        compiled_runnables: dict[str, Callable[..., Any]] | None = None,
        submod_name: str = "",
    ):
        """
        The backend for piecewise compilation.
        It mainly handles the compilation of static shapes and
        dispatching based on runtime shape.

        We will compile `self.graph` once for the general shape,
        and then compile for different shapes specified in
        `compilation_config.compile_sizes`.

        This class supports two mutually exclusive modes:
        1. Compilation (graph is set, compiled_runnables is None):
           Used during initial compilation when we have the FX graph
           and need to compile it for each shape range.
        2. Precompilation (graph is None, compiled_runnables is set):
           Used when loading from cache/AOT artifacts where we already
           have pre-compiled callables and don't need the original graph.

        Exactly one of graph or compiled_runnables must be provided.
        """
        assert bool(graph is not None) ^ bool(compiled_runnables is not None), (
            "exactly one of graph and compiled_runnables should be set."
        )

        self.graph = graph
        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config
        self.piecewise_compile_index = piecewise_compile_index
        self.total_piecewise_compiles = total_piecewise_compiles
        self.vllm_backend = vllm_backend
        self.compiled_runnables = compiled_runnables
        self.submod_name = submod_name

        self.is_first_graph = piecewise_compile_index == 0
        self.is_last_graph = piecewise_compile_index == total_piecewise_compiles - 1

        self.is_full_graph = total_piecewise_compiles == 1
        self.is_encoder_compilation = vllm_backend.is_encoder

        self.compile_ranges = self.compilation_config.get_compile_ranges()
        if self.is_encoder_compilation:
            # For encoder compilation we use the max int32 value
            # to set the upper bound of the compile ranges
            max_int32 = 2**31 - 1
            last_compile_range = self.compile_ranges[-1]
            assert (
                last_compile_range.end
                == vllm_config.scheduler_config.max_num_batched_tokens
            )
            self.compile_ranges[-1] = Range(
                start=last_compile_range.start, end=max_int32
            )

        log_string = f"PiecewiseBackend: compile_ranges: {self.compile_ranges}"
        logger.debug_once(log_string)

        self.compile_sizes = self.compilation_config.compile_sizes
        log_string = f"PiecewiseBackend: compile_sizes: {self.compile_sizes}"
        logger.debug_once(log_string)

        self.sym_shape_indices = sym_shape_indices
        self.returns_tuple = returns_tuple

        # The entries for all ranges we may need to compile or load.
        self.range_entries: dict[Range, RangeEntry] = {}

        # to_be_compiled_ranges tracks the remaining ranges to compile,
        # and updates during the compilation process, so we need to copy it
        self.to_be_compiled_ranges: set[Range] = set(self.compile_ranges)

        # We only keep compilation management inside this class directly.
        if self.compile_sizes is not None:
            for size in self.compile_sizes:
                if isinstance(size, str):
                    assert size == "cudagraph_capture_sizes"
                    raise NotImplementedError(
                        "cudagraph_capture_sizes not supported in compile_sizes."
                        "This should be handled in `post_init_cudagraph_sizes`."
                    )
                else:
                    assert isinstance(size, int)
                    # A concrete compile size is a degenerate single-point range.
                    range = Range(start=size, end=size)
                    if range not in self.compile_ranges:
                        self.range_entries[range] = RangeEntry(
                            compile_range=range,
                        )
                        self.to_be_compiled_ranges.add(range)

        for range in self.compile_ranges:
            self.range_entries[range] = RangeEntry(
                compile_range=range,
            )

        # Track whether we've logged the graph for this subgraph (only log once)
        self._graph_logged = False

        # Get the on_compilation_complete callback from the context var.
        # PiecewiseBackend is created during the first call,
        # which is when the context is set (see compilation/decorators.py)
        from vllm.compilation.backends import _on_compilation_complete_callback

        self.on_compilation_complete = _on_compilation_complete_callback.get()

    def get_compiled_graph_wrapper(
        self, compiled_graph: Callable[..., Any]
    ) -> Callable[..., Any]:
        """Wrap a compiled callable so its output matches the calling
        convention of the original graph (unwraps a singleton tuple/list
        when the graph is not supposed to return a tuple)."""

        def compiled_graph_wrapper(*args: Any) -> Any:
            graph_output = compiled_graph(*args)
            # unpack the tuple if needed
            # TODO(rzou): the implication is that we're not
            # reading the python bytecode correctly in vLLM?
            if self.returns_tuple or not isinstance(graph_output, (tuple, list)):
                return graph_output
            else:
                return graph_output[0]

        return compiled_graph_wrapper

    def check_for_ending_compilation(self) -> None:
        """If this is the last piecewise graph and nothing remains to be
        compiled, persist the compiler cache and fire completion hooks."""
        if self.is_last_graph and not self.to_be_compiled_ranges:
            # no specific sizes to compile
            # save the hash of the inductor graph for the next run
            time_before_saving = time.perf_counter()
            self.vllm_backend.compiler_manager.save_to_file()
            elapsed = time.perf_counter() - time_before_saving
            if elapsed > 1:
                logger.info_once(
                    "Saved compiler manager cache in %.2f seconds.",
                    elapsed,
                    scope="local",
                )

            end_monitoring_torch_compile(self.vllm_config)
            # Call the completion callback (e.g., to save AOT compiled function)
            if self.on_compilation_complete is not None:
                self.on_compilation_complete()

    def to_bytes(self) -> dict[str, bytes]:
        """Serialize every compiled range entry; returns a mapping from
        str(range) to the pickled compiled artifact. Uncompiled entries
        are skipped (with a debug log)."""

        class StandaloneCompiledArtifactsPickler(Pickler):
            # CachingAutotuner holds unpicklable state; strip it first and
            # round-trip the object through a nested pickle.
            def reducer_override(self, obj: object) -> Any:
                if isinstance(obj, CachingAutotuner):
                    obj.prepare_for_pickle()
                    return pickle.loads, (
                        pickle.dumps(
                            obj,
                        ),
                    )
                return NotImplemented

        def serialize(fn: Callable[..., Any]) -> bytes:
            # Only callables exposing torch's serialize() API can be dumped.
            assert hasattr(fn, "serialize"), "fn must have serialize method"
            with torch._functorch.config.patch("bundled_autograd_cache", True):
                entry = fn.serialize()

            f = io.BytesIO()
            StandaloneCompiledArtifactsPickler(f).dump(entry)
            result = f.getvalue()
            return result

        out = {}

        for range_key, entry in self.range_entries.items():
            if not entry.compiled:
                logger.debug(
                    "entry with range %s not compiled, so cannot get its bytes",
                    range_key,
                )
                continue
            if hasattr(entry.runnable, "serialize"):
                out[str(range_key)] = serialize(entry.runnable)

        return out

    def _fakify_args(self, args: tuple[Any, ...]) -> list[Any]:
        """Return the fake example inputs recorded on the graph's
        placeholder nodes, in place of the real runtime args."""
        # We need to pass fake example_inputs, otherwise torch.compile
        # will fakify the example_inputs, potentially causing some non-dynamic
        # dimension to be duck-shaped to other existing shapes that have hints
        # matching their values.
        # This is a problem because it can lead to unintended specializations!
        # If the new wrongly-dynamic dim is specialized,
        # it will force specializing the whole shape.
        # torch.compile probably should not accept
        # non-fake tensors as example inputs!
        # See issue https://github.com/vllm-project/vllm/issues/27899
        fake_example_inputs = []
        assert self.graph is not None
        for node in self.graph.graph.nodes:
            # All placeholders come first in an fx graph.
            if node.op == "placeholder":
                fake_example_inputs.append(node.meta["example_value"])
            else:
                break
        assert len(fake_example_inputs) == len(args)
        return fake_example_inputs

    def _log_compile_start(self, compile_range: Range):
        """Log compilation event for TORCH_TRACE/tlparse."""
        is_cudagraph_size = (
            self.compile_sizes is not None and compile_range.start in self.compile_sizes
        )
        subgraph_index = self.piecewise_compile_index
        submod_name = self.submod_name
        trace_structured(
            "artifact",
            metadata_fn=lambda: {
                "name": "vllm_piecewise_compile_start",
                "encoding": "json",
            },
            payload_fn=lambda: json.dumps(
                {
                    "piecewise_index": subgraph_index,
                    "submod_name": submod_name,
                    "total_piecewise_compiles": self.total_piecewise_compiles,
                    "compile_range_start": compile_range.start,
                    "compile_range_end": compile_range.end,
                    "is_single_size": compile_range.is_single_size(),
                    "is_cudagraph_capture_size": is_cudagraph_size,
                }
            ),
        )

        # Log the subgraph graph dump only once per subgraph (not per size)
        # to reduce log file size. The graph code is the same for all sizes.
        if not self._graph_logged:
            self._graph_logged = True
            assert self.graph is not None
            trace_structured(
                "graph_dump",
                metadata_fn=lambda: {
                    "name": f"vllm_{submod_name}",
                },
                payload_fn=lambda: self.graph.print_readable(print_output=False),
            )

    def _maybe_compile_for_range_entry(
        self, range_entry: RangeEntry, args: tuple[Any, ...]
    ) -> Any:
        """Ensure the given range entry has a runnable: either install the
        precompiled artifact or compile the graph now (first call only)."""
        if not range_entry.compiled:
            if self.compiled_runnables is not None:
                # Precompilation mode: look up the saved artifact by range key.
                range_entry.runnable = self.get_compiled_graph_wrapper(
                    self.compiled_runnables[str(range_entry.compile_range)]
                )
            else:
                self._log_compile_start(range_entry.compile_range)

                # args are real arguments
                # fakify for range, real args for concrete size.
                # For concrete size, we clear the shape env in
                # compiler_manager.compile() so no need to fakify.
                args_list = (
                    self._fakify_args(args)
                    if not range_entry.compile_range.is_single_size()
                    else list(args)
                )

                with (
                    torch._functorch.config.patch("bundled_autograd_cache", True),
                ):
                    range_entry.runnable = self.vllm_backend.compiler_manager.compile(
                        self.graph,
                        args_list,
                        self.vllm_backend.inductor_config,
                        self.compilation_config,
                        compile_range=range_entry.compile_range,
                        graph_index=self.piecewise_compile_index,
                        num_graphs=self.total_piecewise_compiles,
                    )

            range_entry.compiled = True
            self.to_be_compiled_ranges.remove(range_entry.compile_range)

            self.check_for_ending_compilation()

    def _find_range_for_shape(self, runtime_shape: int) -> RangeEntry | None:
        """Resolve the range entry that should serve `runtime_shape`."""
        # First we try to find the range entry for the concrete compile size
        # If not found, we search for the range entry
        # that contains the runtime shape.
        if self.compile_sizes is None:
            return None

        if runtime_shape in self.compile_sizes:
            return self.range_entries[Range(start=runtime_shape, end=runtime_shape)]
        else:
            for range in self.compile_ranges:
                if runtime_shape in range:
                    return self.range_entries[range]
        return None

    def __call__(self, *args: Any) -> Any:
        """Dispatch to (compiling on first use) the runnable that covers
        the runtime shape found at the first symbolic-shape argument index."""
        runtime_shape = args[self.sym_shape_indices[0]]
        range_entry = self._find_range_for_shape(runtime_shape)

        assert range_entry is not None, (
            f"Shape: {runtime_shape} out of considered ranges: {self.compile_ranges}"
        )

        self._maybe_compile_for_range_entry(range_entry, args)
        return range_entry.runnable(*args)
|
||||
321
vllm/compilation/wrapper.py
Normal file
321
vllm/compilation/wrapper.py
Normal file
@@ -0,0 +1,321 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import os
|
||||
import sys
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Callable, Generator
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from types import CodeType
|
||||
from typing import Any, ParamSpec, TypeVar
|
||||
|
||||
import torch
|
||||
import torch._C._dynamo.guards
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import CompilationMode, CUDAGraphMode, get_current_vllm_config
|
||||
from vllm.config.compilation import DynamicShapesType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.nvtx_pytorch_hooks import layerwise_nvtx_marker_context
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
R = TypeVar("R")
|
||||
P = ParamSpec("P")
|
||||
|
||||
|
||||
def _noop_add_global_state_guard(
|
||||
self: torch._C._dynamo.guards.GuardManager, *args: Any, **kwargs: Any
|
||||
) -> None:
|
||||
"""No-op to skip the GLOBAL_STATE guard entirely"""
|
||||
pass
|
||||
|
||||
|
||||
def _noop_add_torch_function_mode_stack_guard(
|
||||
self: torch._C._dynamo.guards.GuardManager, *args: Any, **kwargs: Any
|
||||
) -> None:
|
||||
"""No-op to skip the TORCH_FUNCTION_MODE_STACK guard entirely"""
|
||||
pass
|
||||
|
||||
|
||||
@contextmanager
def _compilation_context() -> Generator[None, None, None]:
    """Context manager for compilation settings and patches.

    This manager:
    1. Sets higher dynamo cache limits for compilation. (Needed for
       qwen2_5_vl, see test_qwen2_5_vl_evs_functionality).
       Generally a recompilation can happen whenever we use a new
       backend instance in torch.compile.
    2. Patches out add_global_state_guard to skip GLOBAL_STATE guards.
    3. Patches out add_torch_function_mode_stack_guard to skip
       TORCH_FUNCTION_MODE_STACK guards.
    4. Restores everything when compilation completes.
    """
    # Save original values
    original_global_state_guard = (
        torch._C._dynamo.guards.GuardManager.add_global_state_guard
    )
    original_torch_function_mode_stack_guard = (
        torch._C._dynamo.guards.GuardManager.add_torch_function_mode_stack_guard
    )
    original_cache_size = torch._dynamo.config.cache_size_limit
    original_accumulated_cache = torch._dynamo.config.accumulated_cache_size_limit

    try:
        # Set higher cache limits for compilation
        torch._dynamo.config.cache_size_limit = 2048
        torch._dynamo.config.accumulated_cache_size_limit = 8192

        # Patch guard manager: replace both guard-adding methods with no-ops.
        torch._C._dynamo.guards.GuardManager.add_global_state_guard = (
            _noop_add_global_state_guard
        )
        torch._C._dynamo.guards.GuardManager.add_torch_function_mode_stack_guard = (
            _noop_add_torch_function_mode_stack_guard
        )
        yield
    finally:
        # Restore original values (even if compilation raised).
        torch._C._dynamo.guards.GuardManager.add_global_state_guard = (
            original_global_state_guard
        )
        torch._C._dynamo.guards.GuardManager.add_torch_function_mode_stack_guard = (
            original_torch_function_mode_stack_guard
        )
        torch._dynamo.config.cache_size_limit = original_cache_size
        torch._dynamo.config.accumulated_cache_size_limit = original_accumulated_cache
|
||||
|
||||
|
||||
class TorchCompileWithNoGuardsWrapper:
    """
    A wrapper class for torch.compile, it ensures that all guards are dropped
    when CompilationMode is not CompilationMode.STOCK_TORCH_COMPILE.
    When guards are dropped, the first time __call__ is invoked, a single
    compilation is triggered. Dynamo should never be traced again after that
    since we drop all guards.
    """

    def check_invariants_and_forward(self, *args: Any, **kwargs: Any) -> Any:
        """Validate shape invariants (provided by the subclass) before
        running forward; used as the compile target for UNBACKED shapes."""
        assert hasattr(self, "_check_shape_invariants")
        self._check_shape_invariants(*args, **kwargs)

        return self.forward(*args, **kwargs)

    def _call_with_optional_nvtx_range(
        self, callable_fn: Callable[P, R], *args: P.args, **kwargs: P.kwargs
    ) -> Any:
        """Invoke `callable_fn`, wrapping it in a layerwise NVTX marker
        context when that tracing is enabled."""
        if self.layerwise_nvtx_tracing_enabled:
            args_list = list(args)
            kwargs_dict = dict(kwargs)
            with layerwise_nvtx_marker_context(
                "Torch Compiled Module (input):{}".format(self.__class__.__name__),
                self,
                in_tensor=args_list,
                kwargs=kwargs_dict,
            ) as ctx:
                # NOTE(review): result is stashed on the context object so the
                # marker hook can observe it; returned inside the with block.
                ctx.result = callable_fn(*args, **kwargs)
                return ctx.result
        return callable_fn(*args, **kwargs)

    def __init__(self) -> None:
        """Build the torch.compile callable for this module according to
        the current vLLM config (guard filtering, AOT compile, bytecode
        hook registration)."""
        self.compiled = False

        vllm_config = get_current_vllm_config()
        self.vllm_config = vllm_config
        mode = vllm_config.compilation_config.mode
        self.layerwise_nvtx_tracing_enabled = (
            vllm_config.observability_config.enable_layerwise_nvtx_tracing
        )
        if mode is None:
            raise RuntimeError("Compilation mode cannot be NO_COMPILATION")

        backend = vllm_config.compilation_config.init_backend(vllm_config)
        options = {}

        if isinstance(backend, str) and backend == "inductor":
            options = vllm_config.compilation_config.inductor_compile_config

        self.first_compile = True
        self.evaluate_guards = (
            vllm_config.compilation_config.dynamic_shapes_config.evaluate_guards
        )

        ds_type = vllm_config.compilation_config.dynamic_shapes_config.type

        if mode != CompilationMode.STOCK_TORCH_COMPILE:
            # Drop all the guards.
            if self.evaluate_guards:
                assert not envs.VLLM_USE_BYTECODE_HOOK, (
                    "compilation_config.dynamic_shapes_config.evaluate_guards "
                    "requires VLLM_USE_BYTECODE_HOOK=0. "
                )

                # Keep only shape-env guards so recompiles can be detected.
                options["guard_filter_fn"] = lambda x: [
                    entry.guard_type == "SHAPE_ENV" for entry in x
                ]
            else:
                # Drop every guard unconditionally.
                options["guard_filter_fn"] = lambda x: [False for _ in x]

        compiled_ptr: Any = self.forward
        # Validate that unbacked dynamic shapes require VLLM_USE_BYTECODE_HOOK=False

        if ds_type == DynamicShapesType.UNBACKED:
            # reason is that bytecode does torch._dynamo.eval_frame.
            # remove_from_cache(self.original_code_object()) to force a new
            # re-compilation. And if we use
            # compiled_ptr = self.check_invariants_and_forward
            # it will reset all entries.
            assert not envs.VLLM_USE_BYTECODE_HOOK, (
                "UNBACKED dynamic shapes requires VLLM_USE_BYTECODE_HOOK=0. "
            )
            assert not self.evaluate_guards, "UNBACKED dynamic shapes do not add guards"

            compiled_ptr = self.check_invariants_and_forward

        aot_context = nullcontext()
        if envs.VLLM_USE_AOT_COMPILE:
            if hasattr(torch._dynamo.config, "enable_aot_compile"):
                aot_context = torch._dynamo.config.patch(enable_aot_compile=True)
            else:
                msg = "torch._dynamo.config.enable_aot_compile is not "
                msg += "available. AOT compile is disabled and please "
                msg += "upgrade PyTorch version to use AOT compile."
                logger.warning(msg)

        with aot_context:
            self._compiled_callable = torch.compile(
                compiled_ptr,
                fullgraph=True,
                dynamic=False,
                backend=backend,
                options=options,
            )

        if envs.VLLM_USE_BYTECODE_HOOK and mode != CompilationMode.STOCK_TORCH_COMPILE:
            torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook)
        self._compiled_bytecode: CodeType | None = None

    def aot_compile(self, *args: Any, **kwargs: Any) -> Any:
        """Ahead-of-time compile for the given example args; requires a
        PyTorch with `aot_compile` support on the compiled callable."""
        if not hasattr(self._compiled_callable, "aot_compile"):
            raise RuntimeError(
                "aot_compile is not supported by the current configuration. "
                "Please make sure torch.compile is enabled with the latest "
                f"version of PyTorch (current using torch: {torch.__version__})"
            )
        return self._compiled_callable.aot_compile((args, kwargs))

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Run the wrapped forward: via the compiled callable on first use,
        then (in bytecode-hook mode) directly via the captured bytecode."""
        if envs.VLLM_USE_BYTECODE_HOOK:
            if (
                self.vllm_config.compilation_config.mode
                == CompilationMode.STOCK_TORCH_COMPILE
            ):
                return self._compiled_callable(*args, **kwargs)

            if not self._compiled_bytecode:
                # Make sure a compilation is triggered by clearing dynamo
                # cache.
                torch._dynamo.eval_frame.remove_from_cache(self.original_code_object())
                return self._call_with_optional_nvtx_range(
                    self._compiled_callable, *args, **kwargs
                )
            else:
                with self._dispatch_to_compiled_code():
                    return self._call_with_optional_nvtx_range(
                        self.forward, *args, **kwargs
                    )
        else:
            # Outside bytecode-hook mode: after the first compile, optionally
            # make any recompilation a hard failure (guard evaluation mode).
            ctx = (
                nullcontext()
                if self.first_compile or not self.evaluate_guards
                else torch.compiler.set_stance("fail_on_recompile")
            )
            self.first_compile = False
            with _compilation_context(), ctx:
                return self._call_with_optional_nvtx_range(
                    self._compiled_callable, *args, **kwargs
                )

    @abstractmethod
    def forward(self, *args: Any, **kwargs: Any) -> Any: ...

    def original_code_object(self) -> CodeType:
        """Return the original code object of the forward method."""
        return self.__class__.forward.__code__

    def bytecode_hook(self, old_code: CodeType, new_code: CodeType) -> None:
        """Hook to save the compiled bytecode for direct execution."""
        if old_code is not self.original_code_object():
            return
        # code borrowed from https://github.com/thuml/depyf/blob/f4ad79fadee27ea113b4c75202db1eb1a11c0dbc/depyf/explain/enable_debugging.py#L25
        frame = sys._getframe()
        while frame and frame.f_back:
            frame = frame.f_back
            code_name = frame.f_code.co_name
            file_name = frame.f_code.co_filename.split(os.path.sep)[-1]
            if code_name == "_compile" and file_name == "convert_frame.py":
                break
        frame = frame.f_locals["frame"]
        assert frame.f_code == old_code

        # Only capture bytecode compiled for THIS instance.
        if frame.f_locals["self"] is not self:
            return

        self._compiled_bytecode = new_code

        path = self.vllm_config.compile_debug_dump_path()
        if path:
            decompiled_file = path / "transformed_code.py"
            if not decompiled_file.exists():
                try:
                    # usually the decompilation will succeed for most models,
                    # as we guarantee a full-graph compilation in Dynamo.
                    # but there's no 100% guarantee, since decompilation is
                    # not a reversible process.
                    import depyf

                    src = depyf.decompile(new_code)

                    with open(decompiled_file, "w") as f:
                        f.write(src)

                    logger.debug("Dynamo transformed code saved to %s", decompiled_file)
                except Exception:
                    # Best-effort debug dump; decompilation failure is benign.
                    pass

        if (
            self.vllm_config.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
            and "update" in new_code.co_names
        ):
            import depyf

            src = depyf.decompile(new_code)
            msg = (
                "Assigning / modifying buffers of nn.Module during forward pass is not "
                "allowed when using cudagraph inside the compiler because it will "
                "cause silent errors. Please use eager mode or fix the code. The "
                "following code contains clues about which buffer is being modified "
                f"(please search for the usage of the function `update`):\n{src}"
            )
            raise RuntimeError(msg)

    @contextmanager
    def _dispatch_to_compiled_code(self) -> Generator[None, None, None]:
        # noqa: E501
        """
        Context manager to dispatch to internally compiled code for torch<2.8.
        Why does this work? Because Dynamo guarantees that the compiled
        bytecode has exactly the same arguments, cell variables, and free
        variables as the original code. Therefore we can directly switch
        the code object in the function and call it.

        See https://dev-discuss.pytorch.org/t/what-is-the-relationship-requirement-among-original-bytecode-transformed-bytecode-and-bytecode-returned-by-hooks-in-dynamo/1693/7 for more details.
        """  # noqa: E501 line too long
        original = self.original_code_object()
        assert self._compiled_bytecode is not None
        self.__class__.forward.__code__ = self._compiled_bytecode
        try:
            yield
        finally:
            # Always restore, even if forward raised.
            self.__class__.forward.__code__ = original
|
||||
Reference in New Issue
Block a user