init
This commit is contained in:
0
vllm/compilation/__init__.py
Normal file
0
vllm/compilation/__init__.py
Normal file
BIN
vllm/compilation/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
vllm/compilation/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm/compilation/__pycache__/inductor_pass.cpython-312.pyc
Normal file
BIN
vllm/compilation/__pycache__/inductor_pass.cpython-312.pyc
Normal file
Binary file not shown.
189
vllm/compilation/activation_quant_fusion.py
Normal file
189
vllm/compilation/activation_quant_fusion.py
Normal file
@@ -0,0 +1,189 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import torch
|
||||
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||
from torch._inductor.pattern_matcher import (PatternMatcherPass, fwd_only,
|
||||
register_replacement)
|
||||
from torch._ops import OpOverload
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey, kFp8StaticTensorSym, kNvfp4Quant, kStaticTensorScale)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
|
||||
from .inductor_pass import enable_fake_mode
|
||||
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
|
||||
logger = init_logger(__name__)

# Dtype constants for the quantized outputs.
FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8

# The unfused activation op that the patterns in this module search for.
SILU_MUL_OP = torch.ops._C.silu_and_mul.default

# Maps a quantization scheme to the fused op that replaces the
# silu_and_mul + quant pair.
FUSED_OPS: dict[QuantKey, OpOverload] = {
    kFp8StaticTensorSym: torch.ops._C.silu_and_mul_quant.default,
}

# The NVFP4 fused op is only registered when running on CUDA and the op
# is actually present in torch.ops._C.
silu_and_mul_nvfp4_quant_supported = (
    current_platform.is_cuda()
    and hasattr(torch.ops._C, "silu_and_mul_nvfp4_quant"))
if silu_and_mul_nvfp4_quant_supported:
    FUSED_OPS[kNvfp4Quant] = (
        torch.ops._C.silu_and_mul_nvfp4_quant.default)
|
||||
|
||||
|
||||
class ActivationQuantPattern(ABC):
    """
    Common base for patterns that fuse an activation op with a quant op.

    Concrete subclasses provide the pattern/replacement pair in
    `register`; this class is abstract and should not be used directly.
    """

    def __init__(
        self,
        quant_key: QuantKey,
    ):
        """
        Resolve the unfused quant op and the fused op for `quant_key`.

        Asserts that the key is present in both QUANT_OPS and FUSED_OPS.
        """
        self.quant_key = quant_key
        self.quant_dtype = quant_key.dtype

        # Unfused quantization op that the search pattern matches.
        assert self.quant_key in QUANT_OPS, \
            f"unsupported quantization scheme {self.quant_key}"
        self.QUANT_OP = QUANT_OPS[self.quant_key]

        # Fused op that the replacement emits.
        assert self.quant_key in FUSED_OPS, \
            f"unsupported fusion scheme {self.quant_key}"
        self.FUSED_OP = FUSED_OPS[self.quant_key]

    def empty_quant(self, *args, **kwargs):
        """
        Call `torch.empty` with dtype defaulting to the quantized dtype and
        device defaulting to "cuda"; explicit keyword arguments win.
        """
        kwargs.setdefault('dtype', self.quant_dtype)
        kwargs.setdefault('device', "cuda")
        return torch.empty(*args, **kwargs)

    @abstractmethod
    def register(self, pm_pass: PatternMatcherPass):
        """Register this pattern and its replacement on `pm_pass`."""
        raise NotImplementedError
|
||||
|
||||
|
||||
class SiluMulFp8StaticQuantPattern(ActivationQuantPattern):
    """
    Fusion for the SiluMul + Fp8StaticQuant pattern.

    Matches SILU_MUL_OP whose output feeds the quant op selected by the
    quant key, and rewrites the pair to the single fused op.
    """

    def __init__(self, symmetric: bool = True):
        super().__init__(
            QuantKey(dtype=FP8_DTYPE,
                     scale=kStaticTensorScale,
                     symmetric=symmetric))

    def register(self, pm_pass: PatternMatcherPass):
        """Register the search pattern and its replacement on `pm_pass`."""

        def pattern(result: torch.Tensor, result_silu_mul: torch.Tensor,
                    input: torch.Tensor, scale: torch.Tensor):
            # Unfused form: silu_and_mul into a temporary, then quantize.
            silu_mul_out = auto_functionalized(SILU_MUL_OP,
                                               result=result_silu_mul,
                                               input=input)
            quant_out = auto_functionalized(self.QUANT_OP,
                                            result=result,
                                            input=silu_mul_out[1],
                                            scale=scale)
            return quant_out[1]

        def replacement(result: torch.Tensor, result_silu_mul: torch.Tensor,
                        input: torch.Tensor, scale: torch.Tensor):
            # Fused form: `result_silu_mul` is accepted but not used.
            fused_out = auto_functionalized(self.FUSED_OP,
                                            result=result,
                                            input=input,
                                            scale=scale)
            return fused_out[1]

        # Example tensors used to trace the pattern and the replacement.
        example_inputs = [
            self.empty_quant(5, 4),  # result
            empty_bf16(5, 4),  # result_silu_mul
            empty_bf16(5, 4),  # input
            empty_fp32(1, 1)  # scale
        ]

        register_replacement(pattern, replacement, example_inputs, fwd_only,
                             pm_pass)
|
||||
|
||||
|
||||
class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
    """
    Fusion for the SiluMul + Nvfp4Quant pattern.

    Matches SILU_MUL_OP whose output feeds the NVFP4 quant op and rewrites
    the pair to the single fused op.
    """

    def __init__(self):
        super().__init__(kNvfp4Quant)

    def register(self, pm_pass: PatternMatcherPass):
        """Register the search pattern and its replacement on `pm_pass`."""

        def pattern(result: torch.Tensor, output_scale: torch.Tensor,
                    result_silu_mul: torch.Tensor, input: torch.Tensor,
                    scale: torch.Tensor):
            # Unfused form: silu_and_mul into a temporary, then quantize.
            silu_mul_out = auto_functionalized(SILU_MUL_OP,
                                               result=result_silu_mul,
                                               input=input)
            quant_out = auto_functionalized(self.QUANT_OP,
                                            output=result,
                                            input=silu_mul_out[1],
                                            output_scale=output_scale,
                                            input_scale=scale)
            return quant_out[1], quant_out[2]

        def replacement(result: torch.Tensor, output_scale: torch.Tensor,
                        result_silu_mul: torch.Tensor, input: torch.Tensor,
                        scale: torch.Tensor):
            # Fused form: `result_silu_mul` is accepted but not used.
            fused_out = auto_functionalized(self.FUSED_OP,
                                            result=result,
                                            result_block_scale=output_scale,
                                            input=input,
                                            input_global_scale=scale)
            return fused_out[1], fused_out[2]

        # Example tensors used to trace the pattern and the replacement.
        example_inputs = [
            self.empty_quant(5, 32),  # result
            empty_i32(128, 4),  # output_scale
            empty_bf16(5, 64),  # result_silu_mul
            empty_bf16(5, 64),  # input
            empty_fp32(1, 1)  # scale
        ]

        register_replacement(pattern, replacement, example_inputs, fwd_only,
                             pm_pass)
|
||||
|
||||
|
||||
class ActivationQuantFusionPass(VllmPatternMatcherPass):
    """
    This pass fuses a pre-defined set of custom ops into fused ops.
    It uses the torch pattern matcher to find the patterns and replace them.

    Because patterns can only be registered once, the pass is a singleton.
    This will be addressed in a future version of PyTorch:
    https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
    """

    @enable_fake_mode
    def __init__(self, config: VllmConfig):
        """Build the pattern matcher pass and register every pattern."""
        super().__init__(config)

        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="activation_quant_fusion_pass")

        # The FP8 pattern is always registered; the NVFP4 pattern only when
        # its fused op was found at import time.
        pattern_classes = [SiluMulFp8StaticQuantPattern]
        if silu_and_mul_nvfp4_quant_supported:
            pattern_classes.append(SiluMulNvfp4QuantPattern)

        for pattern_cls in pattern_classes:
            pattern_cls().register(self.patterns)

        self.dump_patterns(config, self.patterns)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: torch.fx.Graph):
        """Apply the registered patterns to `graph` and record the count."""
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)

    def uuid(self):
        """Return a hash over the source of this pass and its patterns."""
        return VllmInductorPass.hash_source(self, ActivationQuantPattern,
                                            SiluMulFp8StaticQuantPattern,
                                            SiluMulNvfp4QuantPattern)
|
||||
650
vllm/compilation/backends.py
Normal file
650
vllm/compilation/backends.py
Normal file
@@ -0,0 +1,650 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import ast
|
||||
import dataclasses
|
||||
import os
|
||||
import pprint
|
||||
import time
|
||||
from collections.abc import Sequence
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch.fx as fx
|
||||
from torch._dispatch.python import enable_python_dispatcher
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import is_torch_equal_or_newer, resolve_obj_by_qualname
|
||||
|
||||
from .compiler_interface import (CompilerInterface, EagerAdaptor,
|
||||
InductorAdaptor, InductorStandaloneAdaptor)
|
||||
from .counter import compilation_counter
|
||||
from .inductor_pass import InductorPass
|
||||
from .pass_manager import PostGradPassManager
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
    """
    Pick the compiler adaptor for the given compilation config.

    Returns an EagerAdaptor when inductor is not used. Otherwise returns an
    InductorStandaloneAdaptor if standalone compile is requested and
    available, and an InductorAdaptor if not.
    """
    if not compilation_config.use_inductor:
        logger.debug("Using EagerAdaptor")
        return EagerAdaptor()

    # Use standalone compile only if requested, version is new enough,
    # and the symbol actually exists in this PyTorch build.
    use_standalone = (envs.VLLM_USE_STANDALONE_COMPILE
                      and is_torch_equal_or_newer("2.8.0.dev")
                      and hasattr(torch._inductor, "standalone_compile"))
    if use_standalone:
        logger.debug("Using InductorStandaloneAdaptor")
        return InductorStandaloneAdaptor()

    logger.debug("Using InductorAdaptor")
    return InductorAdaptor()
|
||||
|
||||
|
||||
class CompilerManager:
    """
    Manages the compilation process: caching compiled graphs, loading
    them back, and compiling graphs that are not cached yet.

    The cache is a dict mapping
    `(runtime_shape, graph_index, backend_name)`
    to `any_data` returned from the compiler.

    When serializing the cache, we save it to a Python file
    for readability. We don't use json here because json doesn't
    support int as key.
    """

    def __init__(self, compilation_config: CompilationConfig):
        self.cache: dict[tuple[Optional[int], int, str], Any] = {}
        self.is_cache_updated = False
        self.compilation_config = compilation_config
        self.compiler = make_compiler(compilation_config)

    def compute_hash(self, vllm_config: VllmConfig) -> str:
        """Delegate hash computation to the underlying compiler."""
        return self.compiler.compute_hash(vllm_config)

    def initialize_cache(self,
                         cache_dir: str,
                         disable_cache: bool = False,
                         prefix: str = ""):
        """
        Initialize the cache directory for the compiler.

        The organization of the cache directory is as follows:
        cache_dir=/path/to/hash_str/rank_i_j/prefix/
        inside cache_dir, there will be:
        - vllm_compile_cache.py
        - computation_graph.py
        - transformed_code.py

        for multiple prefixes, they can share the same
        base cache dir of /path/to/hash_str/rank_i_j/ ,
        to store some common compilation artifacts.
        """
        self.disable_cache = disable_cache
        self.cache_dir = cache_dir
        self.cache_file_path = os.path.join(cache_dir, "vllm_compile_cache.py")

        cache_file_usable = (not disable_cache
                             and os.path.exists(self.cache_file_path))
        if cache_file_usable:
            # The file holds a Python literal; ast.literal_eval parses it
            # safely. Do not use eval(), it is unsafe.
            with open(self.cache_file_path) as cache_file:
                self.cache = ast.literal_eval(cache_file.read())

        self.compiler.initialize_cache(cache_dir=cache_dir,
                                       disable_cache=disable_cache,
                                       prefix=prefix)

    def save_to_file(self):
        """Write the cache dict to disk if caching is on and it changed."""
        if not self.disable_cache and self.is_cache_updated:
            serialized = pprint.pformat(self.cache, indent=4)
            with open(self.cache_file_path, "w") as cache_file:
                cache_file.write(serialized)

    def load(self,
             graph: fx.GraphModule,
             example_inputs: list[Any],
             graph_index: int,
             runtime_shape: Optional[int] = None) -> Optional[Callable]:
        """
        Load a compiled graph through its cached handle.

        Returns None when no handle is cached for
        `(runtime_shape, graph_index, compiler name)`.
        """
        cache_key = (runtime_shape, graph_index, self.compiler.name)
        if cache_key not in self.cache:
            return None

        handle = self.cache[cache_key]
        compiled_graph = self.compiler.load(handle, graph, example_inputs,
                                            graph_index, runtime_shape)
        if runtime_shape is None:
            logger.debug(
                "Directly load the %s-th graph for dynamic shape from %s via "
                "handle %s", graph_index, self.compiler.name, handle)
        else:
            logger.debug(
                "Directly load the %s-th graph for shape %s from %s via "
                "handle %s", graph_index, str(runtime_shape),
                self.compiler.name, handle)
        return compiled_graph

    def compile(self,
                graph: fx.GraphModule,
                example_inputs,
                additional_inductor_config,
                compilation_config: CompilationConfig,
                graph_index: int = 0,
                num_graphs: int = 1,
                runtime_shape: Optional[int] = None) -> Any:
        """
        Return a compiled callable for `graph`, loading from cache if
        possible and compiling (then caching the handle) otherwise.

        Timing is tracked through the module-level `compilation_start_time`:
        it is reset when `graph_index` is 0 and read when the last graph
        (`graph_index == num_graphs - 1`) has been loaded or compiled.
        """
        global compilation_start_time
        if graph_index == 0:
            # before compiling the first graph, record the start time
            compilation_start_time = time.time()

        compilation_counter.num_backend_compilations += 1

        is_last_graph = graph_index == num_graphs - 1

        # try to load from the cache
        cached_graph = self.load(graph, example_inputs, graph_index,
                                 runtime_shape)
        if cached_graph is not None:
            if is_last_graph:
                # after loading the last graph for this shape, record the
                # time. there can be multiple graphs due to piecewise
                # compilation.
                elapsed = time.time() - compilation_start_time
                if runtime_shape is None:
                    logger.info(
                        "Directly load the compiled graph(s) for dynamic shape "
                        "from the cache, took %.3f s", elapsed)
                else:
                    logger.info(
                        "Directly load the compiled graph(s) for shape %s "
                        "from the cache, took %.3f s", str(runtime_shape),
                        elapsed)
            return cached_graph

        # no compiler cached the graph, or the cache is disabled,
        # we need to compile it.
        # For InductorAdaptor the key is None: compile_fx generates one.
        maybe_key = (
            None if isinstance(self.compiler, InductorAdaptor) else
            f"artifact_shape_{runtime_shape}_subgraph_{graph_index}")
        compiled_graph, handle = self.compiler.compile(
            graph, example_inputs, additional_inductor_config, runtime_shape,
            maybe_key)

        assert compiled_graph is not None, "Failed to compile the graph"

        # store the artifact in the cache
        if not envs.VLLM_DISABLE_COMPILE_CACHE and handle is not None:
            cache_key = (runtime_shape, graph_index, self.compiler.name)
            self.cache[cache_key] = handle
            compilation_counter.num_cache_entries_updated += 1
            self.is_cache_updated = True
            if graph_index == 0:
                # adds some info logging for the first graph
                if runtime_shape is None:
                    logger.info(
                        "Cache the graph for dynamic shape for later use")
                else:
                    logger.info("Cache the graph of shape %s for later use",
                                str(runtime_shape))
            if runtime_shape is None:
                logger.debug(
                    "Store the %s-th graph for dynamic shape from %s via "
                    "handle %s", graph_index, self.compiler.name, handle)
            else:
                logger.debug(
                    "Store the %s-th graph for shape %s from %s via handle %s",
                    graph_index, str(runtime_shape), self.compiler.name,
                    handle)

        # after compiling the last graph, record the end time
        if is_last_graph:
            elapsed = time.time() - compilation_start_time
            compilation_config.compilation_time += elapsed
            if runtime_shape is None:
                logger.info("Compiling a graph for dynamic shape takes %.2f s",
                            elapsed)
            else:
                logger.info("Compiling a graph for shape %s takes %.2f s",
                            runtime_shape, elapsed)

        return compiled_graph
|
||||
|
||||
|
||||
@dataclasses.dataclass
class SplitItem:
    """One top-level submodule produced by `split_graph`."""
    # attribute name of the submodule on the split graph module
    submod_name: str
    # integer suffix parsed from `submod_name`
    graph_id: int
    # True when this submodule holds one of the splitting ops
    is_splitting_graph: bool
    # the submodule itself
    graph: fx.GraphModule


def split_graph(graph: fx.GraphModule,
                ops: list[str]) -> tuple[fx.GraphModule, list[SplitItem]]:
    """
    Split `graph` into submodules at every call to one of `ops`.

    Each `call_function` node whose `str(node.target)` is in `ops` is placed
    alone in its own subgraph; the nodes between two such ops share a
    subgraph.

    Args:
        graph: the graph module to split.
        ops: string forms of the call targets to split on.

    Returns:
        The stitched graph module from `split_module`, and one `SplitItem`
        per top-level submodule, sorted by integer `graph_id`.
    """
    # Import the submodule explicitly instead of relying on
    # `torch.fx.passes.split_module` having been imported elsewhere.
    from torch.fx.passes.split_module import split_module

    # Sets give O(1) membership tests for every node / submodule.
    splitting_ops = frozenset(ops)
    splitting_graph_ids: set[int] = set()

    subgraph_id = 0
    node_to_subgraph_id: dict[fx.Node, int] = {}
    for node in graph.graph.nodes:
        if node.op in ("output", "placeholder"):
            continue
        if node.op == 'call_function' and str(node.target) in splitting_ops:
            # Bump the id before and after so the splitting op sits alone.
            subgraph_id += 1
            node_to_subgraph_id[node] = subgraph_id
            splitting_graph_ids.add(subgraph_id)
            subgraph_id += 1
        else:
            node_to_subgraph_id[node] = subgraph_id

    # `keep_original_order` is important!
    # otherwise pytorch might reorder the nodes and
    # the semantics of the graph will change when we
    # have mutations in the graph
    split_gm = split_module(graph,
                            None,
                            lambda node: node_to_subgraph_id[node],
                            keep_original_order=True)

    # Only direct children are submodules of interest; nested modules and
    # the root module itself are not visited.
    outputs = []
    for name, module in split_gm.named_children():
        graph_id = int(name.replace("submod_", ""))
        outputs.append(
            SplitItem(name, graph_id, graph_id in splitting_graph_ids,
                      module))

    # sort by integer graph_id, rather than string name
    outputs.sort(key=lambda item: item.graph_id)

    return split_gm, outputs
|
||||
|
||||
|
||||
# Wall-clock start of the current compilation round; CompilerManager.compile
# resets it when it handles graph index 0.
compilation_start_time = 0.0


class PiecewiseCompileInterpreter(torch.fx.Interpreter):
    """Code adapted from `torch.fx.passes.shape_prop.ShapeProp`.
    It runs the given graph with fake inputs, and compile some
    submodules specified by `compile_submod_names` with the given
    compilation configs.

    NOTE: the order in `compile_submod_names` matters, because
    it will be used to determine the order of the compiled piecewise
    graphs. The first graph will handle logging, and the last graph
    has some special cudagraph output handling.
    """

    def __init__(self, module: torch.fx.GraphModule,
                 compile_submod_names: list[str], vllm_config: VllmConfig,
                 vllm_backend: "VllmBackend"):
        super().__init__(module)
        from torch._guards import detect_fake_mode
        self.fake_mode = detect_fake_mode()
        self.compile_submod_names = compile_submod_names
        self.compilation_config = vllm_config.compilation_config
        self.vllm_config = vllm_config
        self.vllm_backend = vllm_backend
        # When True, it annoyingly dumps the torch.fx.Graph on errors.
        self.extra_traceback = False

    def run(self, *args):
        """Run the graph with every tensor argument converted to fake."""
        converted_args = [
            self.fake_mode.from_tensor(arg)
            if isinstance(arg, torch.Tensor) else arg for arg in args
        ]
        with self.fake_mode, enable_python_dispatcher():
            return super().run(*converted_args)

    def call_module(self, target: torch.fx.node.Target,
                    args: tuple[torch.fx.node.Argument,
                                ...], kwargs: dict[str, Any]) -> Any:
        """
        Execute the submodule, and if it is listed in
        `compile_submod_names`, compile it and install the compiled backend
        on the parent module under the same name.
        """
        assert isinstance(target, str)
        output = super().call_module(target, args, kwargs)

        if target not in self.compile_submod_names:
            return output

        graph_index = self.compile_submod_names.index(target)
        total_graphs = len(self.compile_submod_names)
        submod = self.fetch_attr(target)
        # Positions of the arguments that are symbolic ints.
        sym_shape_indices = [
            pos for pos, arg in enumerate(args)
            if isinstance(arg, torch.SymInt)
        ]

        compiler_manager = self.vllm_backend.compiler_manager
        compiled_graph_for_dynamic_shape = compiler_manager.compile(
            submod,
            args,
            self.compilation_config.inductor_compile_config,
            self.compilation_config,
            graph_index=graph_index,
            num_graphs=total_graphs,
            runtime_shape=None)

        # Lazy import here to avoid circular import
        from .cuda_piecewise_backend import PiecewiseBackend

        piecewise_backend = PiecewiseBackend(
            submod, self.vllm_config, graph_index, total_graphs,
            sym_shape_indices, compiled_graph_for_dynamic_shape,
            self.vllm_backend)

        wrap_in_static_graph = (
            self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
            and not self.compilation_config.use_inductor_graph_partition)
        if wrap_in_static_graph:
            # We're using Dynamo-based piecewise splitting, so we wrap
            # the whole subgraph with a static graph wrapper.
            from .cuda_graph import CUDAGraphOptions

            # resolve the static graph wrapper class (e.g. CUDAGraphWrapper
            # class) as platform dependent.
            static_graph_wrapper_class = resolve_obj_by_qualname(
                current_platform.get_static_graph_wrapper_cls())

            # Always assign PIECEWISE runtime mode to the
            # CUDAGraphWrapper for piecewise_backend, to distinguish
            # it from the FULL cudagraph runtime mode, no matter it
            # is wrapped on a full or piecewise fx graph.
            options = CUDAGraphOptions(
                debug_log_enable=piecewise_backend.is_first_graph,
                gc_disable=not piecewise_backend.is_first_graph,
                weak_ref_output=piecewise_backend.is_last_graph)
            self.module.__dict__[target] = static_graph_wrapper_class(
                runnable=piecewise_backend,
                vllm_config=self.vllm_config,
                runtime_mode=CUDAGraphMode.PIECEWISE,
                cudagraph_options=options)
        else:
            self.module.__dict__[target] = piecewise_backend

        compilation_counter.num_piecewise_capturable_graphs_seen += 1

        return output
|
||||
|
||||
|
||||
# the tag for the part of model being compiled,
# e.g. backbone/eagle_head
model_tag: str = "backbone"


@contextmanager
def set_model_tag(tag: str):
    """
    Temporarily switch the module-level `model_tag` to `tag`.

    Asserts that `tag` differs from the current tag. The previous tag is
    restored on exit, including when the body raises.
    """
    global model_tag
    assert tag != model_tag, \
        f"Model tag {tag} is the same as the current tag {model_tag}."
    previous_tag, model_tag = model_tag, tag
    try:
        yield
    finally:
        model_tag = previous_tag
|
||||
|
||||
|
||||
class VllmBackend:
    """The compilation backend for `torch.compile` with vLLM.
    It is used for compilation level of `CompilationLevel.PIECEWISE`,
    where we customize the compilation.

    The major work of this backend is to split the graph into
    piecewise graphs, and pass them to the piecewise backend.

    This backend also adds the PostGradPassManager to Inductor config,
    which handles the post-grad passes.
    """

    vllm_config: VllmConfig
    compilation_config: CompilationConfig
    _called: bool = False
    # the graph we compiled
    graph: fx.GraphModule
    # the stitching graph module for all the piecewise graphs
    split_gm: fx.GraphModule
    piecewise_graphs: list[SplitItem]
    returned_callable: Callable
    # Inductor passes to run on the graph pre-defunctionalization
    post_grad_passes: Sequence[Callable]
    sym_tensor_indices: list[int]
    input_buffers: list[torch.Tensor]
    compiler_manager: CompilerManager

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        # if the model is initialized with a non-empty prefix,
        # then usually it's enough to use that prefix,
        # e.g. language_model, vision_model, etc.
        # when multiple parts are initialized as independent
        # models, we need to use the model_tag to distinguish
        # them, e.g. backbone (default), eagle_head, etc.
        self.prefix = prefix or model_tag

        # Passes to run on the graph post-grad.
        self.post_grad_pass_manager = PostGradPassManager()

        self.sym_tensor_indices = []
        self.input_buffers = []

        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config

        self.compiler_manager: CompilerManager = CompilerManager(
            self.compilation_config)

        # `torch.compile` is JIT compiled, so we don't need to
        # do anything here

    def configure_post_pass(self):
        """
        Configure the post-grad pass manager and install it in the inductor
        config under the `post_grad_custom_post_pass` key.
        """
        self.post_grad_pass_manager.configure(self.vllm_config)

        # Post-grad custom passes are run using the post_grad_custom_post_pass
        # hook. If a pass for that hook exists, add it to the pass manager.
        inductor_config = self.compilation_config.inductor_compile_config
        PASS_KEY = "post_grad_custom_post_pass"
        if PASS_KEY in inductor_config:
            existing_pass = inductor_config[PASS_KEY]
            if isinstance(existing_pass, PostGradPassManager):
                # PassManager already added to config, make sure it's correct
                assert (inductor_config[PASS_KEY].uuid() ==
                        self.post_grad_pass_manager.uuid())
            else:
                # Config should automatically wrap all inductor passes
                assert isinstance(inductor_config[PASS_KEY], InductorPass)
                self.post_grad_pass_manager.add(existing_pass)
        inductor_config[PASS_KEY] = self.post_grad_pass_manager

    def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
        """
        Split `graph`, compile the non-splitting pieces, and return the
        callable to run.

        Returns the split graph module directly, or a wrapper that first
        copies symbolic-shaped inputs into static buffers when cudagraph
        input copying is enabled. Each instance can only be called once.
        """
        vllm_config = self.vllm_config
        if not self.compilation_config.cache_dir:
            # no provided cache dir, generate one based on the known factors
            # that affects the compilation. if none of the factors change,
            # the cache dir will be the same so that we can reuse the compiled
            # graph.

            # 0. factors come from the env, for example, The values of
            # VLLM_PP_LAYER_PARTITION will affect the computation graph.
            # 1. factors come from the vllm_config (it mainly summarizes how
            # the model is created)
            factors = [envs.compute_hash(), vllm_config.compute_hash()]

            # 2. factors come from the code files that are traced by Dynamo (
            # it mainly summarizes how the model is used in forward pass)
            forward_code_files = sorted(self.compilation_config.traced_files)
            self.compilation_config.traced_files.clear()
            logger.debug(
                "Traced files (to be considered for compilation cache):\n%s",
                "\n".join(forward_code_files))
            hash_content = []
            for filepath in forward_code_files:
                hash_content.append(filepath)
                if filepath == "<string>":
                    # This means the function was dynamically generated, with
                    # e.g. exec(). We can't actually check these.
                    continue
                with open(filepath) as source_file:
                    hash_content.append(source_file.read())
            import hashlib
            code_hash = hashlib.md5("\n".join(hash_content).encode(),
                                    usedforsecurity=False).hexdigest()
            factors.append(code_hash)

            # 3. compiler hash
            factors.append(self.compiler_manager.compute_hash(vllm_config))

            # combine all factors to generate the cache dir
            hash_key = hashlib.md5(str(factors).encode(),
                                   usedforsecurity=False).hexdigest()[:10]

            self.compilation_config.cache_dir = os.path.join(
                envs.VLLM_CACHE_ROOT,
                "torch_compile_cache",
                hash_key,
            )

        cache_dir = self.compilation_config.cache_dir
        os.makedirs(cache_dir, exist_ok=True)
        self.compilation_config.cache_dir = cache_dir
        parallel_config = vllm_config.parallel_config
        rank = parallel_config.rank
        dp_rank = parallel_config.data_parallel_rank
        local_cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}",
                                       self.prefix)
        os.makedirs(local_cache_dir, exist_ok=True)
        self.compilation_config.local_cache_dir = local_cache_dir

        disable_cache = envs.VLLM_DISABLE_COMPILE_CACHE

        if disable_cache:
            logger.info("vLLM's torch.compile cache is disabled.")
        else:
            logger.info("Using cache directory: %s for vLLM's torch.compile",
                        local_cache_dir)

        self.compiler_manager.initialize_cache(local_cache_dir, disable_cache,
                                               self.prefix)

        # when dynamo calls the backend, it means the bytecode
        # transform and analysis are done
        compilation_counter.num_graphs_seen += 1
        from .monitor import torch_compile_start_time
        dynamo_time = time.time() - torch_compile_start_time
        logger.info("Dynamo bytecode transform time: %.2f s", dynamo_time)
        self.compilation_config.compilation_time += dynamo_time

        # we control the compilation process, each instance can only be
        # called once
        assert not self._called, "VllmBackend can only be called once"

        self.graph = graph
        self.configure_post_pass()

        self.split_gm, self.piecewise_graphs = split_graph(
            graph, self.compilation_config.splitting_ops)

        from torch._dynamo.utils import lazy_format_graph_code

        # depyf will hook lazy_format_graph_code and dump the graph
        # for debugging, no need to print the graph here
        lazy_format_graph_code("before split", self.graph)
        lazy_format_graph_code("after split", self.split_gm)

        compilation_counter.num_piecewise_graphs_seen += len(
            self.piecewise_graphs)
        # Only the pieces that do not hold a splitting op get compiled.
        submod_names_to_compile = [
            piece.submod_name for piece in self.piecewise_graphs
            if not piece.is_splitting_graph
        ]

        # propagate the split graph to the piecewise backend,
        # compile submodules with symbolic shapes
        interpreter = PiecewiseCompileInterpreter(self.split_gm,
                                                  submod_names_to_compile,
                                                  self.vllm_config, self)
        interpreter.run(*example_inputs)

        graph_path = os.path.join(local_cache_dir, "computation_graph.py")
        if not os.path.exists(graph_path):
            # code adapted from https://github.com/thuml/depyf/blob/dab831108a752d1facc00acdd6d4243891845c37/depyf/explain/patched_lazy_format_graph_code.py#L30 # noqa
            # use `print_readable` because it can include submodules
            readable = self.split_gm.print_readable(print_output=False)
            src = "from __future__ import annotations\nimport torch\n" + \
                readable
            src = src.replace("<lambda>", "GraphModule")
            with open(graph_path, "w") as graph_file:
                graph_file.write(src)

            logger.debug("Computation graph saved to %s", graph_path)

        self._called = True

        if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE or \
                not self.compilation_config.cudagraph_copy_inputs:
            return self.split_gm

        # if we need to copy input buffers for cudagraph
        from torch._guards import detect_fake_mode
        fake_mode = detect_fake_mode()
        fake_args = [
            fake_mode.from_tensor(arg)
            if isinstance(arg, torch.Tensor) else arg
            for arg in example_inputs
        ]

        # index of tensors that have symbolic shapes (batch size)
        # for weights and static buffers, they will have concrete shapes.
        # symbolic shape only happens for input tensors.
        from torch.fx.experimental.symbolic_shapes import is_symbolic
        self.sym_tensor_indices = [
            pos for pos, arg in enumerate(fake_args)
            if isinstance(arg, torch._subclasses.fake_tensor.FakeTensor)
            and any(is_symbolic(dim) for dim in arg.size())
        ]

        # compiler managed cudagraph input buffers
        # we assume the first run with symbolic shapes
        # has the maximum size among all the tensors
        self.input_buffers = [
            example_inputs[pos].clone() for pos in self.sym_tensor_indices
        ]

        # this is the callable we return to Dynamo to run
        def copy_and_call(*args):
            list_args = list(args)
            for buffer, index in zip(self.input_buffers,
                                     self.sym_tensor_indices):
                runtime_tensor = list_args[index]
                runtime_shape = runtime_tensor.shape[0]
                static_tensor = buffer[:runtime_shape]

                # copy the tensor to the static buffer
                static_tensor.copy_(runtime_tensor)

                # replace the tensor in the list_args to the static buffer
                list_args[index] = static_tensor
            return self.split_gm(*list_args)

        return copy_and_call
|
||||
56
vllm/compilation/base_static_graph.py
Normal file
56
vllm/compilation/base_static_graph.py
Normal file
@@ -0,0 +1,56 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Callable, Protocol
|
||||
|
||||
from vllm.config import CUDAGraphMode, VllmConfig
|
||||
|
||||
|
||||
class AbstractStaticGraphWrapper(Protocol):
    """
    StaticGraphWrapper interface that allows platforms to wrap a callable
    to be captured as a static graph.

    This is a `typing.Protocol`: it only describes the constructor and call
    signatures a platform-specific wrapper is expected to provide. Both
    methods here raise `NotImplementedError`.
    """

    def __init__(
        self,
        runnable: Callable[..., Any],
        vllm_config: VllmConfig,
        runtime_mode: CUDAGraphMode,
        **kwargs: Any,
    ) -> None:
        """
        Initializes the StaticGraphWrapper class with graph capturing and
        execution-related configurations.

        Args:
            runnable (Callable): The callable to be wrapped and captured.
            vllm_config (VllmConfig): Global configuration for vLLM.
            runtime_mode (CUDAGraphMode): The style of the static
                graph runtime. See CUDAGraphMode in vllm/config.py.
                Note that only the subset enum `NONE`, `PIECEWISE` and `FULL`
                are used as concrete runtime mode for cudagraph dispatching.
        Keyword Args:
            kwargs: Additional keyword arguments for platform-specific
                configurations.

        Raises:
            NotImplementedError: Always; implementations must override this.
        """
        raise NotImplementedError

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """
        Executes the wrapped callable.

        If the current runtime mode in the ForwardContext matches the runtime
        mode of this instance, it replays the CUDAGraph or captures it using
        the callable if it hasn't been captured yet. Otherwise, it calls the
        original callable directly.

        Args:
            *args: Variable length input arguments to be passed into the
                callable.
            **kwargs: Keyword arguments to be passed into the callable.

        Returns:
            Any: Output of the executed callable.

        Raises:
            NotImplementedError: Always; implementations must override this.
        """
        raise NotImplementedError
|
||||
1188
vllm/compilation/collective_fusion.py
Normal file
1188
vllm/compilation/collective_fusion.py
Normal file
File diff suppressed because it is too large
Load Diff
573
vllm/compilation/compiler_interface.py
Normal file
573
vllm/compilation/compiler_interface.py
Normal file
@@ -0,0 +1,573 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import contextlib
|
||||
import copy
|
||||
import hashlib
|
||||
import os
|
||||
from contextlib import ExitStack
|
||||
from typing import Any, Callable, Optional
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
import torch._inductor.compile_fx
|
||||
import torch.fx as fx
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.utils import is_torch_equal_or_newer
|
||||
|
||||
from .inductor_pass import pass_context
|
||||
|
||||
|
||||
class CompilerInterface:
    """
    The interface for a compiler that can be used by vLLM.

    Every method has a default body, so a subclass only overrides what it
    supports: the defaults do nothing in `initialize_cache`, contribute an
    empty string to the hash, report a failed compilation from `compile`,
    and reject `load`.
    """
    # The name of the compiler, e.g. inductor.
    # This is a class-level attribute.
    name: str

    def initialize_cache(self,
                         cache_dir: str,
                         disable_cache: bool = False,
                         prefix: str = ""):
        """
        When the vLLM process uses `cache_dir` as the cache directory,
        the compiler should initialize itself with the cache directory,
        e.g. by re-directing its own cache directory to a sub-directory.

        `prefix` can be used in combination with `cache_dir` to figure out
        the base cache directory, e.g. when multiple parts of a model are
        being compiled but should all share the same cache directory.

        e.g.
        cache_dir = "/path/to/dir/backbone", prefix = "backbone"
        cache_dir = "/path/to/dir/eagle_head", prefix = "eagle_head"

        The default implementation does nothing.
        """
        pass

    def compute_hash(self, vllm_config: VllmConfig) -> str:
        """
        Gather all the relevant information from the vLLM config,
        to compute a hash so that we can cache the compiled model.

        See [`VllmConfig.compute_hash`][vllm.config.VllmConfig.compute_hash]
        to check what information
        is already considered by default. This function should only
        consider the information that is specific to the compiler.

        The default implementation returns an empty string.
        """
        return ""

    def compile(
        self,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        compiler_config: dict[str, Any],
        runtime_shape: Optional[int] = None,
        key: Optional[str] = None,
    ) -> tuple[Optional[Callable], Optional[Any]]:
        """
        Compile the graph with the given example inputs and compiler config,
        with a runtime shape. If the `runtime_shape` is None, it means
        the `example_inputs` have a dynamic shape. Otherwise, the
        `runtime_shape` specifies the shape of the inputs. Right now we only
        support one variable shape for all inputs, which is the batchsize
        (number of tokens) during inference.

        Dynamo will make sure `graph(*example_inputs)` is valid.

        The function should return a compiled callable function, as well as
        a handle that can be used to directly load the compiled function.

        The handle should be a plain Python object, preferably a string or a
        file path for readability.

        If the compiler doesn't support caching, it should return None for the
        handle. If the compiler fails to compile the graph, it should return
        None for the compiled function as well.

        `key` is required for `InductorStandaloneAdaptor`; it specifies where
        to save the compiled artifact. The compiled artifact gets saved to
        `cache_dir/key`.

        The default implementation returns `(None, None)`.
        """
        return None, None

    def load(self,
             handle: Any,
             graph: fx.GraphModule,
             example_inputs: list[Any],
             graph_index: int,
             runtime_shape: Optional[int] = None) -> Callable:
        """
        Load the compiled function from the handle.
        Raises an error if the handle is invalid.

        The handle is the second return value of the `compile` function.

        The default implementation always raises `NotImplementedError`.
        """
        raise NotImplementedError("caching is not supported")
|
||||
|
||||
|
||||
class AlwaysHitShapeEnv:
    """
    Dummy shape environment whose guard checks always succeed.

    Background:

    With plain `torch.compile`, each compilation consists of one Dynamo
    bytecode compilation and one Inductor compilation. Inductor runs inside
    the Dynamo compilation context and takes its dynamic-shape information
    from there.

    vLLM runs the Dynamo bytecode compilation only once and then invokes
    Inductor several times: once per specific shape plus once for a general
    shape. The shape-specific compilations run outside the Dynamo context,
    so there is no shape environment to hand to Inductor and its code cache
    lookup would fail.

    Supplying this stand-in makes the Inductor code cache lookup always hit,
    which lets the graph be compiled for different shapes as needed.

    The set of methods below was found by trial-and-error until it worked.
    """

    def __init__(self) -> None:
        # no guards are ever recorded
        self.guards: list[Any] = []

    def evaluate_guards_expression(self, *args, **kwargs):
        """Report every guard expression as satisfied."""
        return True

    def get_pruned_guards(self, *args, **kwargs):
        """Return an empty guard list."""
        return []

    def produce_guards_expression(self, *args, **kwargs):
        """Return an empty guard expression."""
        return ""
|
||||
|
||||
|
||||
def get_inductor_factors() -> list[Any]:
|
||||
factors: list[Any] = []
|
||||
# summarize system state
|
||||
from torch._inductor.codecache import CacheBase
|
||||
system_factors = CacheBase.get_system()
|
||||
factors.append(system_factors)
|
||||
|
||||
# summarize pytorch state
|
||||
from torch._inductor.codecache import torch_key
|
||||
torch_factors = torch_key()
|
||||
factors.append(torch_factors)
|
||||
return factors
|
||||
|
||||
|
||||
class InductorStandaloneAdaptor(CompilerInterface):
    """
    Inductor adaptor built on `torch._inductor.standalone_compile`.

    Requires PyTorch 2.8+. It is not enabled by default yet; the plan is to
    turn it on by default for PyTorch 2.8.

    Use VLLM_USE_STANDALONE_COMPILE to toggle this on or off.
    """
    name = "inductor_standalone"

    def compute_hash(self, vllm_config: VllmConfig) -> str:
        """Return the first 10 hex digits of the MD5 of the Inductor factors.

        `vllm_config` is not consulted here.
        """
        digest = hashlib.md5(str(get_inductor_factors()).encode(),
                             usedforsecurity=False)
        return digest.hexdigest()[:10]

    def initialize_cache(self,
                         cache_dir: str,
                         disable_cache: bool = False,
                         prefix: str = ""):
        """Remember `cache_dir`; `disable_cache` and `prefix` are unused."""
        self.cache_dir = cache_dir

    def compile(
        self,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        compiler_config: dict[str, Any],
        runtime_shape: Optional[int] = None,
        key: Optional[str] = None,
    ) -> tuple[Optional[Callable], Optional[Any]]:
        """
        Compile `graph` with `standalone_compile` and optionally save it.

        Args:
            graph: The FX graph to compile.
            example_inputs: Example inputs passed on to Inductor.
            compiler_config: Inductor config patches; may be None.
            runtime_shape: A concrete shape to compile for, or None for the
                dynamic-shape compilation.
            key: Name of the artifact under `self.cache_dir`; must not be
                None.

        Returns:
            The compiled graph and the handle `(key, path)`. The artifact is
            written to `path` unless VLLM_DISABLE_COMPILE_CACHE is set.
        """
        compilation_counter.num_inductor_compiles += 1

        inductor_options = ({} if compiler_config is None else
                            dict(compiler_config))
        set_inductor_config(inductor_options, runtime_shape)

        # a concrete shape takes its sizes from the example inputs,
        # otherwise they come from the tracing context
        shape_source = ("from_example_inputs" if isinstance(
            runtime_shape, int) else "from_tracing_context")

        from torch._inductor import standalone_compile
        with pass_context(runtime_shape):
            compiled_graph = standalone_compile(
                graph,
                example_inputs,
                dynamic_shapes=shape_source,
                options={"config_patches": inductor_options})

        # Save the compiled artifact to disk in the specified path
        assert key is not None
        artifact_path = os.path.join(self.cache_dir, key)
        if not envs.VLLM_DISABLE_COMPILE_CACHE:
            compiled_graph.save(path=artifact_path, format="unpacked")
            compilation_counter.num_compiled_artifacts_saved += 1
        return compiled_graph, (key, artifact_path)

    def load(self,
             handle: Any,
             graph: fx.GraphModule,
             example_inputs: list[Any],
             graph_index: int,
             runtime_shape: Optional[int] = None) -> Callable:
        """
        Load a saved artifact and wrap it in Dynamo's calling convention.

        Args:
            handle: The `(key, path)` tuple returned by `compile`.
            graph: The FX graph, used to tell whether it returns a tuple.
            example_inputs: Unused here.
            graph_index: Unused here.
            runtime_shape: Unused here.

        Returns:
            A callable taking `*args` that runs the loaded artifact.
        """
        assert isinstance(handle, tuple)
        assert isinstance(handle[0], str)
        assert isinstance(handle[1], str)

        artifact = torch._inductor.CompiledArtifact.load(path=handle[1],
                                                         format="unpacked")

        from torch._inductor.compile_fx import graph_returns_tuple
        returns_tuple = graph_returns_tuple(graph)

        def compiled_graph_wrapper(*args):
            result = artifact(*args)
            # unpack the tuple if needed
            # TODO(rzou): the implication is that we're not
            # reading the python bytecode correctly in vLLM?
            return result if returns_tuple else result[0]

        return compiled_graph_wrapper
|
||||
|
||||
|
||||
class InductorAdaptor(CompilerInterface):
    """
    The adaptor for the Inductor compiler, version 2.5, 2.6, 2.7.

    Compilation goes through `torch._inductor.compile_fx.compile_fx` with
    several Inductor internals monkey-patched, so that the FX graph cache
    key and the path of the generated file can be captured. `load` later
    uses the cache key to fetch the compiled graph from Inductor's cache.
    """
    name = "inductor"

    def compute_hash(self, vllm_config: VllmConfig) -> str:
        """Return the first 10 hex digits of the MD5 of the Inductor factors.

        `vllm_config` is not consulted here.
        """
        factors = get_inductor_factors()
        hash_str = hashlib.md5(str(factors).encode(),
                               usedforsecurity=False).hexdigest()[:10]
        return hash_str

    def initialize_cache(self,
                         cache_dir: str,
                         disable_cache: bool = False,
                         prefix: str = ""):
        """
        Record the cache locations and redirect Inductor/Triton caches.

        Args:
            cache_dir: Cache directory for this compiled part of the model.
            disable_cache: When True, only the attributes are set and the
                environment is left untouched.
            prefix: Trailing portion of `cache_dir` that is stripped to
                obtain the base cache directory.
        """
        self.cache_dir = cache_dir
        self.prefix = prefix
        self.base_cache_dir = cache_dir[:-len(prefix)] if prefix else cache_dir
        if disable_cache:
            return
        # redirect the cache directory to a sub-directory
        # set flags so that Inductor and Triton store their cache
        # in the cache_dir, then users only need to copy the cache_dir
        # to another machine to reuse the cache.
        inductor_cache = os.path.join(self.base_cache_dir, "inductor_cache")
        os.makedirs(inductor_cache, exist_ok=True)
        os.environ["TORCHINDUCTOR_CACHE_DIR"] = inductor_cache
        triton_cache = os.path.join(self.base_cache_dir, "triton_cache")
        os.makedirs(triton_cache, exist_ok=True)
        os.environ["TRITON_CACHE_DIR"] = triton_cache

    def _find_compiled_file_path(self, compiled_fn: Callable) -> str:
        """Return the file that defines `compiled_fn`'s Inductor code.

        If `compiled_fn` itself is not defined under `self.base_cache_dir`,
        it is a wrapper (hooked in the align_inputs_from_check_idxs function
        in torch/_inductor/utils.py), so its closure is searched for a
        callable that is. Falls back to `compiled_fn`'s own file.
        """
        file_path = compiled_fn.__code__.co_filename
        if file_path.startswith(self.base_cache_dir):
            return file_path
        # `__closure__` is None for functions without free variables
        for cell in compiled_fn.__closure__ or ():
            contents = cell.cell_contents
            if not callable(contents):
                continue
            # callables that are not plain functions carry no `__code__`
            code = getattr(contents, "__code__", None)
            if code is not None and code.co_filename.startswith(
                    self.base_cache_dir):
                # this is the real file path compiled from Inductor
                return code.co_filename
        return file_path

    def compile(
        self,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        compiler_config: dict[str, Any],
        runtime_shape: Optional[int] = None,
        key: Optional[str] = None,
    ) -> tuple[Optional[Callable], Optional[Any]]:
        """
        Compile `graph` with Inductor and capture its cache key.

        Args:
            graph: The FX graph to compile; a deep copy is handed to
                Inductor.
            example_inputs: Example inputs passed on to Inductor.
            compiler_config: Inductor config patches; may be None.
            runtime_shape: A concrete shape to compile for, or None for the
                dynamic-shape compilation.
            key: Unused by this adaptor.

        Returns:
            The compiled callable and the handle `(hash_str, file_path)`.

        Raises:
            RuntimeError: If the installed torch is older than 2.5, or if
                the compile cache is enabled and no cache key was captured.
        """
        compilation_counter.num_inductor_compiles += 1
        from torch._inductor.compile_fx import compile_fx
        current_config = {}
        if compiler_config is not None:
            current_config.update(compiler_config)

        # force the local FX graph cache on and the remote cache off
        current_config["fx_graph_cache"] = True
        current_config["fx_graph_remote_cache"] = False

        set_inductor_config(current_config, runtime_shape)

        # inductor can inplace modify the graph, so we need to copy it
        # see https://github.com/pytorch/pytorch/issues/138980
        graph = copy.deepcopy(graph)

        # it's the first time we compile this graph
        # the assumption is that we don't have nested Inductor compilation.
        # compiled_fx_graph_hash will only be called once, and we can hook
        # it to get the hash of the compiled graph directly.

        hash_str, file_path = None, None
        from torch._inductor.codecache import (FxGraphCache,
                                               compiled_fx_graph_hash)
        if torch.__version__.startswith("2.5"):
            original_load = FxGraphCache.load
            original_load_name = "torch._inductor.codecache.FxGraphCache.load"

            def hijack_load(*args, **kwargs):
                nonlocal file_path
                inductor_compiled_graph = original_load(*args, **kwargs)
                file_path = self._find_compiled_file_path(
                    inductor_compiled_graph.current_callable)
                return inductor_compiled_graph

            hijacked_compile_fx_inner = (
                torch._inductor.compile_fx.compile_fx_inner)
        elif is_torch_equal_or_newer("2.6"):
            # NOTE: compare parsed versions, not strings; "2.10" >= "2.6"
            # is False lexicographically.
            # function renamed in 2.6
            original_load_name = None

            def hijacked_compile_fx_inner(*args, **kwargs):
                nonlocal hash_str, file_path
                output = torch._inductor.compile_fx.compile_fx_inner(
                    *args, **kwargs)
                inductor_compiled_graph = output
                if inductor_compiled_graph is not None:
                    file_path = self._find_compiled_file_path(
                        inductor_compiled_graph.current_callable)
                    hash_str = inductor_compiled_graph._fx_graph_cache_key
                return output
        else:
            raise RuntimeError(
                "InductorAdaptor requires torch>=2.5, but found "
                f"torch=={torch.__version__}")

        def hijack_compiled_fx_graph_hash(*args, **kwargs):
            out = compiled_fx_graph_hash(*args, **kwargs)
            nonlocal hash_str
            hash_str = out[0]
            return out

        def _check_can_cache(*args, **kwargs):
            # no error means it can be cached.
            # Inductor refuses to cache the graph outside of Dynamo
            # tracing context, and also disables caching for graphs
            # with high-order ops.
            # For vLLM, in either case, we want to cache the graph.
            # see https://github.com/pytorch/pytorch/blob/9f5ebf3fc609105a74eab4ccc24932d6353ff566/torch/_inductor/codecache.py#L1221 # noqa
            return

        def _get_shape_env() -> AlwaysHitShapeEnv:
            return AlwaysHitShapeEnv()

        with ExitStack() as stack:
            # hijack to get the compiled graph itself
            if original_load_name is not None:
                stack.enter_context(patch(original_load_name, hijack_load))

            # for hijacking the hash of the compiled graph
            stack.enter_context(
                patch("torch._inductor.codecache.compiled_fx_graph_hash",
                      hijack_compiled_fx_graph_hash))

            # for providing a dummy shape environment
            stack.enter_context(
                patch("torch._inductor.codecache.FxGraphCache._get_shape_env",
                      _get_shape_env))

            from torch._functorch._aot_autograd.autograd_cache import (
                AOTAutogradCache)

            # torch 2.8+ on main uses _get_shape_env in AOTAutogradCache
            if hasattr(AOTAutogradCache, "_get_shape_env"):
                stack.enter_context(
                    patch(
                        "torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env",
                        _get_shape_env))

            # for forcing the graph to be cached
            stack.enter_context(
                patch(
                    "torch._inductor.codecache.FxGraphCache._check_can_cache",
                    _check_can_cache))

            # Dynamo metrics context, see method for more details.
            stack.enter_context(self.metrics_context())

            # Disable remote caching. When these are on, on remote cache-hit,
            # the monkey-patched functions never actually get called.
            # vLLM today assumes and requires the monkey-patched functions to
            # get hit.
            # TODO(zou3519): we're going to replace this all with
            # standalone_compile sometime.
            if is_torch_equal_or_newer("2.6"):
                stack.enter_context(
                    torch._inductor.config.patch(fx_graph_remote_cache=False))
                # InductorAdaptor (unfortunately) requires AOTAutogradCache
                # to be turned off to run. It will fail to acquire the hash_str
                # and error if not.
                # StandaloneInductorAdaptor (PyTorch 2.8+) fixes this problem.
                stack.enter_context(
                    torch._functorch.config.patch(enable_autograd_cache=False))
                stack.enter_context(
                    torch._functorch.config.patch(
                        enable_remote_autograd_cache=False))

            with pass_context(runtime_shape):
                compiled_graph = compile_fx(
                    graph,
                    example_inputs,
                    inner_compile=hijacked_compile_fx_inner,
                    config_patches=current_config)

        # We treat VLLM_DISABLE_COMPILE_CACHE as the overall switch for torch
        # compilation cache. So turn off the checks if we disable the
        # compilation cache.
        if not envs.VLLM_DISABLE_COMPILE_CACHE:
            if hash_str is None:
                raise RuntimeError(
                    "vLLM failed to compile the model. The most "
                    "likely reason for this is that a previous compilation "
                    "failed, leading to a corrupted compilation artifact. "
                    "We recommend trying to "
                    "remove ~/.cache/vllm/torch_compile_cache and try again "
                    "to see the real issue. ")
            assert file_path is not None, (
                "failed to get the file path of the compiled graph")
        return compiled_graph, (hash_str, file_path)

    def load(self,
             handle: Any,
             graph: fx.GraphModule,
             example_inputs: list[Any],
             graph_index: int,
             runtime_shape: Optional[int] = None) -> Callable:
        """
        Look up a compiled graph in Inductor's FX graph cache.

        Args:
            handle: The `(hash_str, file_path)` tuple returned by `compile`;
                only the hash is used for the lookup.
            graph: The FX graph the handle was produced from.
            example_inputs: Example inputs passed to the cache lookup.
            graph_index: Unused here.
            runtime_shape: Unused here.

        Returns:
            A callable taking `*args` (Dynamo calling convention) that runs
            the cached Inductor graph.

        Raises:
            RuntimeError: If the installed torch is older than 2.5.
        """
        assert isinstance(handle, tuple)
        assert isinstance(handle[0], str)
        assert isinstance(handle[1], str)
        hash_str = handle[0]

        from torch._functorch._aot_autograd.autograd_cache import (
            AOTAutogradCache)
        from torch._inductor.codecache import FxGraphCache
        with ExitStack() as exit_stack:
            exit_stack.enter_context(
                patch("torch._inductor.codecache.FxGraphCache._get_shape_env",
                      lambda *args, **kwargs: AlwaysHitShapeEnv()))
            # torch 2.8+ on main uses _get_shape_env in AOTAutogradCache
            if hasattr(AOTAutogradCache, "_get_shape_env"):
                exit_stack.enter_context(
                    patch(
                        "torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env",
                        lambda *args, **kwargs: AlwaysHitShapeEnv()))

            # Dynamo metrics context, see method for more details.
            exit_stack.enter_context(self.metrics_context())

            if torch.__version__.startswith("2.5"):
                inductor_compiled_graph = FxGraphCache._lookup_graph(
                    hash_str, example_inputs, True, False)
            elif is_torch_equal_or_newer("2.6"):
                # NOTE: compare parsed versions, not strings; "2.10" >= "2.6"
                # is False lexicographically.
                from torch._inductor.output_code import (
                    CompiledFxGraphConstantsWithGm)
                constants = CompiledFxGraphConstantsWithGm(graph)
                inductor_compiled_graph, _ = FxGraphCache._lookup_graph(
                    hash_str, example_inputs, True, None, constants)
            else:
                raise RuntimeError(
                    "InductorAdaptor requires torch>=2.5, but found "
                    f"torch=={torch.__version__}")
            assert inductor_compiled_graph is not None, (
                "Inductor cache lookup failed. Please remove "
                "the cache directory and try again.")

        # Inductor calling convention (function signature):
        # f(list) -> tuple
        # Dynamo calling convention (function signature):
        # f(*args) -> Any

        # need to know if the graph returns a tuple
        from torch._inductor.compile_fx import graph_returns_tuple
        returns_tuple = graph_returns_tuple(graph)

        # this is the callable we return to Dynamo to run
        def compiled_graph(*args):
            # convert args to list
            list_args = list(args)
            graph_output = inductor_compiled_graph(list_args)
            # unpack the tuple if needed
            if returns_tuple:
                return graph_output
            else:
                return graph_output[0]

        return compiled_graph

    def metrics_context(self) -> contextlib.AbstractContextManager:
        """
        This method returns the Dynamo metrics context (if it exists,
        otherwise a null context). It is used by various compile components.
        Present in torch>=2.6, it's used inside FxGraphCache in
        torch==2.6 (but not after). It might also be used in various other
        torch.compile internal functions.

        Because it is re-entrant, we always set it (even if entering via Dynamo
        and the context was already entered). We might want to revisit if it
        should be set at a different level of compilation.

        This is likely a bug in PyTorch: public APIs should not rely on
        manually setting up internal contexts. But we also rely on non-public
        APIs which might not provide these guarantees.
        """
        if is_torch_equal_or_newer("2.6"):
            import torch._dynamo.utils
            return torch._dynamo.utils.get_metrics_context()
        else:
            return contextlib.nullcontext()
|
||||
|
||||
|
||||
def set_inductor_config(config, runtime_shape):
    """
    Add the shape-specific tuning flags to an Inductor config in place.

    Args:
        config: Mutable mapping of Inductor config patches.
        runtime_shape: The concrete shape being compiled for. Anything that
            is not an `int` leaves `config` untouched.
    """
    if not isinstance(runtime_shape, int):
        return

    # for a specific batchsize, tuning triton kernel parameters
    # can be beneficial
    config["max_autotune"] = envs.VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE
    config["coordinate_descent_tuning"] = (
        envs.VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING)
|
||||
|
||||
|
||||
class EagerAdaptor(CompilerInterface):
    """Adaptor that performs no compilation and returns the graph as-is."""
    name = "eager"

    def compile(
        self,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        compiler_config: dict[str, Any],
        runtime_shape: Optional[int] = None,
        key: Optional[str] = None,
    ) -> tuple[Optional[Callable], Optional[Any]]:
        """
        Count the call and hand back `graph` unchanged.

        Returns:
            `(graph, None)`; the handle is None because nothing is cached.
        """
        compilation_counter.num_eager_compiles += 1
        # we don't need to compile the graph, the graph itself is the
        # callable. It does not support caching, so there is no handle.
        compiled, handle = graph, None
        return compiled, handle
|
||||
47
vllm/compilation/counter.py
Normal file
47
vllm/compilation/counter.py
Normal file
@@ -0,0 +1,47 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import copy
|
||||
import dataclasses
|
||||
from contextlib import contextmanager
|
||||
|
||||
|
||||
@dataclasses.dataclass
class CompilationCounter:
    """Counters for compilation-related events."""
    num_models_seen: int = 0
    num_graphs_seen: int = 0
    # including the splitting ops
    num_piecewise_graphs_seen: int = 0
    # not including the splitting ops
    num_piecewise_capturable_graphs_seen: int = 0
    num_backend_compilations: int = 0
    # Number of gpu_model_runner attempts to trigger CUDAGraphs capture
    num_gpu_runner_capture_triggers: int = 0
    # Number of CUDAGraphs captured
    num_cudagraph_captured: int = 0
    # Number of Inductor adaptor `compile` calls
    num_inductor_compiles: int = 0
    # Number of eager adaptor `compile` calls
    num_eager_compiles: int = 0
    # The number of times vLLM's compiler cache entry was updated
    num_cache_entries_updated: int = 0
    # The number of standalone_compile compiled artifacts saved
    num_compiled_artifacts_saved: int = 0
    # Number of times a model was loaded with CompilationLevel.DYNAMO_AS_IS
    dynamo_as_is_count: int = 0

    def clone(self) -> "CompilationCounter":
        """Return an independent deep copy of this counter."""
        return copy.deepcopy(self)

    @contextmanager
    def expect(self, **kwargs):
        """
        Context manager asserting how much counters change inside the block.

        Each keyword names a counter attribute and gives the difference it
        must show between entering and leaving the `with` block. The check
        runs only when the block finishes without raising.

        Raises:
            AssertionError: If a counter's change differs from the expected
                one.
        """
        snapshot = self.clone()
        yield
        for field_name, expected_delta in kwargs.items():
            actual_delta = (getattr(self, field_name) -
                            getattr(snapshot, field_name))
            assert actual_delta == expected_delta, (
                f"{field_name} not as expected, "
                f"before it is {getattr(snapshot, field_name)}"
                f", after it is {getattr(self, field_name)}, "
                f"expected diff is {expected_delta}")


compilation_counter = CompilationCounter()
|
||||
199
vllm/compilation/cuda_graph.py
Normal file
199
vllm/compilation/cuda_graph.py
Normal file
@@ -0,0 +1,199 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import dataclasses
|
||||
from contextlib import ExitStack
|
||||
from typing import Any, Callable, Optional
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.monitor import validate_cudagraph_capturing_enabled
|
||||
from vllm.config import CUDAGraphMode, VllmConfig
|
||||
from vllm.distributed.device_communicators.pynccl_allocator import (
|
||||
set_graph_pool_id)
|
||||
from vllm.forward_context import BatchDescriptor, get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import weak_ref_tensors
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclasses.dataclass
class CUDAGraphEntry:
    # Per-batch-descriptor record kept by CUDAGraphWrapper in its
    # `concrete_cudagraph_entries` dict.

    # the batch descriptor this entry was created for (also its dict key)
    batch_descriptor: BatchDescriptor
    # None until a graph has been captured; the wrapper captures when it
    # finds this still unset
    cudagraph: Optional[torch.cuda.CUDAGraph] = None
    # presumably the output produced during capture and returned on replay
    # -- the assignment is not visible here, TODO confirm
    output: Optional[Any] = None

    # for cudagraph debugging, track the input addresses
    # during capture, and check if they are the same during replay
    # (recorded as `data_ptr()` of each positional tensor argument)
    input_addresses: Optional[list[int]] = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
class CUDAGraphOptions:
    # Optional per-wrapper settings; CUDAGraphWrapper builds a default
    # instance when none is passed in.

    # emit the "Capturing a cudagraph on ..." debug log before a capture
    debug_log_enable: bool = True
    # consulted inside the capture context; presumably suppresses garbage
    # collection while capturing -- TODO confirm against the capture code
    gc_disable: bool = False
    # NOTE(review): presumably keeps the captured output only as weak
    # references -- usage is not visible here, confirm
    weak_ref_output: bool = True
|
||||
|
||||
|
||||
class CUDAGraphWrapper:
|
||||
"""Wraps a runnable to add CUDA graph capturing and replaying ability. And
|
||||
provide attribute access to the underlying `runnable` via `__getattr__`.
|
||||
|
||||
The workflow of this wrapper in the cudagraph dispatching is as follows:
|
||||
1. At initialization, a runtime mode is assigned to the wrapper (FULL or
|
||||
PIECEWISE).
|
||||
2. At runtime, the wrapper receives a runtime_mode and a
|
||||
batch_descriptor(key) from the forward context and blindly trust them
|
||||
for cudagraph dispatching.
|
||||
3. If runtime_mode is NONE or runtime_mode does not match the mode of the
|
||||
wrapper, just call the runnable directly.
|
||||
4. Otherwise, i.e., the runtime_mode matches the mode of the wrapper,
|
||||
the wrapper will perform cudagraph capture(if key does not exist, create
|
||||
a new entry and cache it) or replay (if key exists in the cache).
|
||||
|
||||
Note: CUDAGraphWrapper does not store persistent buffers or copy any
|
||||
runtime inputs into that buffers for replay. We assume implementing them
|
||||
is done outside of the wrapper. That is because we do not make any
|
||||
assumption on the dynamic shape (batch size) of the runtime inputs, as a
|
||||
trade-off for staying orthogonal to compilation logic. Nevertheless,
|
||||
tracing and checking the input addresses to be consistent during replay is
|
||||
guaranteed when VLLM_LOGGING_LEVEL == "DEBUG".
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
runnable: Callable,
|
||||
vllm_config: VllmConfig,
|
||||
runtime_mode: CUDAGraphMode,
|
||||
cudagraph_options: Optional[CUDAGraphOptions] = None):
|
||||
self.runnable = runnable
|
||||
self.vllm_config = vllm_config
|
||||
self.runtime_mode = runtime_mode
|
||||
self.compilation_config = vllm_config.compilation_config
|
||||
|
||||
self.first_run_finished = False
|
||||
self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"
|
||||
|
||||
# assert runtime_mode is not NONE(no cudagraph), otherwise, we don't
|
||||
# need to initialize a CUDAGraphWrapper.
|
||||
assert self.runtime_mode != CUDAGraphMode.NONE
|
||||
# TODO: in the future, if we want to use multiple
|
||||
# streams, it might not be safe to share a global pool.
|
||||
# only investigate this when we use multiple streams
|
||||
self.graph_pool = current_platform.get_global_graph_pool()
|
||||
|
||||
if cudagraph_options is None:
|
||||
cudagraph_options = CUDAGraphOptions()
|
||||
self.cudagraph_options = cudagraph_options
|
||||
# the entries for different batch descriptors that we need to capture
|
||||
# cudagraphs for.
|
||||
self.concrete_cudagraph_entries: dict[BatchDescriptor, CUDAGraphEntry]\
|
||||
= {}
|
||||
|
||||
def __getattr__(self, key: str):
|
||||
# allow accessing the attributes of the runnable.
|
||||
if hasattr(self.runnable, key):
|
||||
return getattr(self.runnable, key)
|
||||
raise AttributeError(f"Attribute {key} not exists in the runnable of "
|
||||
f"cudagraph wrapper: {self.runnable}")
|
||||
|
||||
def unwrap(self) -> Callable:
|
||||
# in case we need to access the original runnable.
|
||||
return self.runnable
|
||||
|
||||
def __call__(self, *args, **kwargs):
    """Run the wrapped callable eagerly, or capture / replay a cudagraph.

    The forward context supplies the batch descriptor and the requested
    cudagraph runtime mode. If that mode is NONE or differs from this
    wrapper's mode, the runnable is called directly. Otherwise the first
    call for a batch descriptor captures a cudagraph and returns the
    runnable's output; later calls replay the captured graph and return
    the stored (weak-referenced) output.
    """
    ctx = get_forward_context()
    descriptor = ctx.batch_descriptor
    requested_mode = ctx.cudagraph_runtime_mode

    if (requested_mode == CUDAGraphMode.NONE
            or requested_mode != self.runtime_mode):
        # CUDAGraphMode.NONE could mean the profile run, a warmup run, or
        # running without cudagraphs.
        # Capture/replay is not triggered when the runtime mode does not
        # match. This lets nested CUDAGraphWrapper instances with
        # different runtime modes each pick up only their own calls.
        return self.runnable(*args, **kwargs)

    if descriptor not in self.concrete_cudagraph_entries:
        # first time we see this batch descriptor: create its entry
        self.concrete_cudagraph_entries[descriptor] = CUDAGraphEntry(
            batch_descriptor=descriptor)
    entry = self.concrete_cudagraph_entries[descriptor]

    if entry.cudagraph is not None:
        # ---- replay path ----
        if self.is_debugging_mode:
            # the captured graph is only valid for the same input buffers
            new_input_addresses = [
                t.data_ptr() for t in args if isinstance(t, torch.Tensor)
            ]
            assert new_input_addresses == entry.input_addresses, (
                f"Input addresses for cudagraphs are different "
                f"during replay. Expected {entry.input_addresses}, "
                f"got {new_input_addresses}")

        entry.cudagraph.replay()
        return entry.output

    # ---- capture path ----
    if self.cudagraph_options.debug_log_enable:
        # Many shapes are captured and capturing is fast, so logging is
        # opt-in, e.g. only for the first subgraph in piecewise mode.
        logger.debug("Capturing a cudagraph on (%s,%s)",
                     self.runtime_mode.name, entry.batch_descriptor)
    # validate that cudagraph capturing is legal at this point.
    validate_cudagraph_capturing_enabled()

    entry.input_addresses = [
        t.data_ptr() for t in args if isinstance(t, torch.Tensor)
    ]
    cudagraph = torch.cuda.CUDAGraph()

    with ExitStack() as stack:
        if self.cudagraph_options.gc_disable:
            # In piecewise mode every model forward captures many
            # cudagraphs (roughly one per layer). Running gc for each of
            # them makes capture very slow, so gc runs only for the
            # first graph and is stubbed out for the rest.
            stack.enter_context(patch("gc.collect", lambda: None))
            stack.enter_context(
                patch("torch.cuda.empty_cache", lambda: None))

        pool_id = (self.graph_pool if self.graph_pool is not None else
                   current_platform.graph_pool_handle())
        set_graph_pool_id(pool_id)

        # mind-exploding: carefully manage the reference and memory.
        with torch.cuda.graph(cudagraph, pool=self.graph_pool):
            # `output` is managed by pytorch's cudagraph pool
            output = self.runnable(*args, **kwargs)
            if self.cudagraph_options.weak_ref_output:
                # Converting to a weak ref releases the original
                # `output` immediately to save memory. This is only
                # safe for the last graph in piecewise cudagraph mode,
                # whose output is not consumed by any other cuda graph.
                output = weak_ref_tensors(output)

    # the entry always stores a weak ref of the output, to save memory
    entry.output = weak_ref_tensors(output)
    entry.cudagraph = cudagraph

    compilation_counter.num_cudagraph_captured += 1

    # important: return the output itself rather than the weak ref kept
    # on the entry, so that pytorch can correctly manage the memory
    # during cuda graph capture
    return output
|
||||
117
vllm/compilation/cuda_piecewise_backend.py
Normal file
117
vllm/compilation/cuda_piecewise_backend.py
Normal file
@@ -0,0 +1,117 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import dataclasses
|
||||
from typing import Any, Callable
|
||||
|
||||
import torch.fx as fx
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.backends import VllmBackend
|
||||
from vllm.compilation.monitor import end_monitoring_torch_compile
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclasses.dataclass
class ConcreteSizeEntry:
    """Bookkeeping record for one concrete runtime shape."""
    # the concrete shape (an int) this entry is kept for
    runtime_shape: int
    # flipped to True once a shape-specific compilation was requested
    compiled: bool = False
    # callable used for this shape; defaults to None until one is assigned
    runnable: Callable = None  # type: ignore
|
||||
|
||||
|
||||
class PiecewiseBackend:
    """Callable that runs one piece of a piecewise-compiled graph.

    Calls are dispatched on the runtime shape: shapes listed in
    `compilation_config.compile_sizes` get their own compilation the
    first time they are seen, every other shape uses the graph compiled
    for the general shape.
    """

    def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
                 piecewise_compile_index: int, total_piecewise_compiles: int,
                 sym_shape_indices: list[int],
                 compiled_graph_for_general_shape: Callable,
                 vllm_backend: VllmBackend):
        """
        The backend for piecewise compilation.
        It mainly handles the compilation of static shapes and
        dispatching based on runtime shape.

        `self.graph` is compiled once for the general shape (that result
        is passed in as `compiled_graph_for_general_shape`), and then
        again for each shape in `compilation_config.compile_sizes`.
        """
        self.graph = graph
        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config
        self.vllm_backend = vllm_backend

        # position of this piece among all pieces of the split graph
        self.piecewise_compile_index = piecewise_compile_index
        self.total_piecewise_compiles = total_piecewise_compiles
        self.is_first_graph = piecewise_compile_index == 0
        self.is_last_graph = (
            piecewise_compile_index == total_piecewise_compiles - 1)
        self.is_full_graph = total_piecewise_compiles == 1

        self.compile_sizes: set[int] = set(
            self.compilation_config.compile_sizes)

        self.first_run_finished = False

        self.compiled_graph_for_general_shape = compiled_graph_for_general_shape  # noqa

        self.sym_shape_indices = sym_shape_indices

        self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"

        # `to_be_compiled_sizes` shrinks as shapes get compiled, so it
        # has to be a copy rather than an alias of `compile_sizes`.
        self.to_be_compiled_sizes: set[int] = self.compile_sizes.copy()

        # Entries for the shapes we need to compile. Only compilation
        # management is kept inside this class directly; until a shape
        # is compiled its entry runs the general-shape graph.
        self.concrete_size_entries: dict[int, ConcreteSizeEntry] = {
            size: ConcreteSizeEntry(
                runtime_shape=size,
                runnable=self.compiled_graph_for_general_shape,
            )
            for size in self.compile_sizes
        }

    def check_for_ending_compilation(self):
        """Finish compilation once the last piece has nothing left to do."""
        if not self.is_last_graph or self.to_be_compiled_sizes:
            return
        # no specific sizes (left) to compile:
        # save the hash of the inductor graph for the next run
        self.vllm_backend.compiler_manager.save_to_file()
        end_monitoring_torch_compile(self.vllm_config)

    def __call__(self, *args) -> Any:
        """Run this piece for ``args``, compiling for the shape if needed."""
        if not self.first_run_finished:
            # the very first call always goes through the general graph
            self.first_run_finished = True
            self.check_for_ending_compilation()
            return self.compiled_graph_for_general_shape(*args)

        # the runtime shape is read from the first symbolic-shape argument
        shape = args[self.sym_shape_indices[0]]

        if shape not in self.concrete_size_entries:
            # we don't need to do anything for this shape
            return self.compiled_graph_for_general_shape(*args)

        size_entry = self.concrete_size_entries[shape]

        if not size_entry.compiled:
            size_entry.compiled = True
            self.to_be_compiled_sizes.remove(shape)
            manager = self.vllm_backend.compiler_manager
            # args are real arguments
            size_entry.runnable = manager.compile(
                self.graph,
                args,
                self.compilation_config.inductor_compile_config,
                self.compilation_config,
                graph_index=self.piecewise_compile_index,
                num_graphs=self.total_piecewise_compiles,
                runtime_shape=shape)

            # finished compilations for all required shapes
            if self.is_last_graph and not self.to_be_compiled_sizes:
                self.check_for_ending_compilation()

        return size_entry.runnable(*args)
|
||||
400
vllm/compilation/decorators.py
Normal file
400
vllm/compilation/decorators.py
Normal file
@@ -0,0 +1,400 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import contextlib
|
||||
import inspect
|
||||
from typing import Callable, Optional, TypeVar, Union, overload
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from packaging import version
|
||||
from torch._dynamo.symbolic_convert import InliningInstructionTranslator
|
||||
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
|
||||
from vllm.config import CompilationLevel, VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import resolve_obj_by_qualname, supports_dynamo
|
||||
|
||||
from .monitor import start_monitoring_torch_compile
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
IGNORE_COMPILE_KEY = "_ignore_compile_vllm"
|
||||
|
||||
_T = TypeVar("_T", bound=type[nn.Module])
|
||||
|
||||
|
||||
def ignore_torch_compile(cls: _T) -> _T:
    """
    Class decorator that opts `cls` out of `support_torch_compile`.

    Useful when a parent class carries a `support_torch_compile`
    decorator but the subclass `cls` should not be compiled. Only the
    forward of the class the decorator is applied to is affected.

    If the parent has `ignore_torch_compile` but the child has
    `support_torch_compile`, the child will still be compiled.

    If the class has one or more submodules with the
    `support_torch_compile` decorator applied, compile will not be
    ignored for those submodules.
    """
    marker = IGNORE_COMPILE_KEY
    setattr(cls, marker, True)
    return cls
|
||||
|
||||
|
||||
def _should_ignore_torch_compile(cls) -> bool:
    """
    Tell whether `cls` is marked to be skipped by torch.compile.

    Reads the `IGNORE_COMPILE_KEY` attribute of the class and treats a
    missing attribute as "do not ignore".
    """
    ignore_flag = getattr(cls, IGNORE_COMPILE_KEY, False)
    return ignore_flag
|
||||
|
||||
|
||||
@overload
def support_torch_compile(
    *,
    enable_if: Optional[Callable[[VllmConfig], bool]] = None,
) -> Callable[[_T], _T]:
    ...


@overload
def support_torch_compile(
    *,
    dynamic_arg_dims: Optional[dict[str, Union[int, list[int]]]],
) -> Callable[[_T], _T]:
    ...


@overload
def support_torch_compile(cls: _T) -> _T:
    ...


def support_torch_compile(
    cls: Optional[_T] = None,
    *,
    dynamic_arg_dims: Optional[dict[str, Union[int, list[int]]]] = None,
    enable_if: Optional[Callable[[VllmConfig], bool]] = None,
) -> Union[Callable[[_T], _T], _T]:
    """
    Class decorator that makes the `forward` method of a class compilable.

    Usage 1: apply it directly, without arguments:

    ```python
    @support_torch_compile
    class MyModel(nn.Module):
        def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]):
            ...
    ```

    Usage 2: apply it with arguments:

    ```python
    @support_torch_compile(dynamic_arg_dims={"x": 0, "y": 0})
    class MyModel(nn.Module):
        def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]):
            ...
    ```

    `dynamic_arg_dims` maps argument names of `forward` to the dynamic
    dimension(s) of that argument, given as one integer or a list of
    integers.

    When `dynamic_arg_dims` is `None` it is inferred from the type
    annotations of `forward` with these default rules:

    - an argument annotated as `torch.Tensor` or
        `Optional[torch.Tensor]` gets its first dimension marked as
        dynamic.
    - an argument annotated as `IntermediateTensors` gets the first
        dimension of every tensor it holds marked as dynamic.

    At runtime, what is actually marked depends on the argument value:

    - a tensor: the given dimension(s) (negative values allowed) are
        marked as dynamic.
    - `None`: ignored.
    - `IntermediateTensors`: all the tensors in it are marked as dynamic.
    - anything else: an error is raised.

    NOTE: an argument that is `None` must stay `None` for the whole
    lifetime of the model, otherwise the model cannot be captured as a
    single computation graph.

    `enable_if` takes a `VllmConfig` and returns whether the model should
    be compiled at all, which allows compiling only under certain
    conditions.
    """

    def cls_decorator_helper(cls: _T) -> _T:
        # Resolves `dynamic_arg_dims` here so that
        # `_support_torch_compile` does not need the extra indentation.
        if not hasattr(cls, 'forward'):
            raise TypeError("decorated class should have a forward method.")
        forward_sig = inspect.signature(cls.forward)
        resolved_dims = dynamic_arg_dims
        if resolved_dims is None:
            # annotations for which dimension 0 is dynamic by default
            dynamic_annotations = [
                torch.Tensor, Optional[torch.Tensor], IntermediateTensors,
                Optional[IntermediateTensors]
            ]
            resolved_dims = {
                name: 0
                for name, param in forward_sig.parameters.items()
                if param.annotation in dynamic_annotations
            }
            logger.debug(("Inferred dynamic dimensions for "
                          "forward method of %s: %s"), cls,
                         list(resolved_dims.keys()))

        if len(resolved_dims) == 0:
            raise ValueError(
                "No dynamic dimensions found in the forward method of "
                f"{cls}. Please provide dynamic_arg_dims explicitly.")

        # every named argument must really exist on `forward`
        for k in resolved_dims:
            if k in forward_sig.parameters:
                continue
            raise ValueError(
                f"Argument {k} not found in the forward method of {cls}")
        return _support_torch_compile(cls, resolved_dims, enable_if)

    if cls is None:
        # called with keyword arguments: hand back the real decorator
        return cls_decorator_helper

    # use `support_torch_compile` as a decorator without arguments
    assert isinstance(cls, type)
    return cls_decorator_helper(cls)
|
||||
|
||||
|
||||
def _support_torch_compile(
    cls: _T,
    dynamic_arg_dims: dict[str, Union[int, list[int]]],
    enable_if: Optional[Callable[[VllmConfig], bool]] = None,
) -> _T:
    """
    Patch `cls` in place so that its forward method can be compiled.

    `TorchCompileWrapperWithCustomDispatcher` is appended to the bases of
    `cls`, and `__init__` / `__call__` are replaced by wrappers defined
    below.

    :param cls: the class to patch; it is returned after patching.
    :param dynamic_arg_dims: maps `forward` argument names to the
        dimension(s) to mark as dynamic before the first compilation.
    :param enable_if: optional predicate on the `VllmConfig`; when it
        returns False the instance is not compiled.
    """
    if TorchCompileWrapperWithCustomDispatcher in cls.__bases__:
        # support decorating multiple times
        return cls

    # take care of method resolution order
    # make sure super().__init__ is called on the base class
    #  other than TorchCompileWrapperWithCustomDispatcher
    cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher, )

    old_init = cls.__init__

    # explicitly clear a marker that may be inherited from a parent that
    # was decorated with `ignore_torch_compile`
    setattr(cls, IGNORE_COMPILE_KEY, False)

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs):
        old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs)
        self.vllm_config = vllm_config
        enable_compile = enable_if is None or enable_if(vllm_config)
        # for CompilationLevel.DYNAMO_AS_IS , the upper level model runner
        # will handle the compilation, so we don't need to do anything here.
        self.do_not_compile = \
            vllm_config.compilation_config.level in [
                CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS
            ] or not supports_dynamo() or _should_ignore_torch_compile(
                self.__class__) or not enable_compile
        if self.do_not_compile:
            return

        compilation_counter.num_models_seen += 1
        TorchCompileWrapperWithCustomDispatcher.__init__(
            self, compilation_level=vllm_config.compilation_config.level)

    cls.__init__ = __init__

    def __call__(self, *args, **kwargs):
        # torch.compiler.is_compiling() means we are inside the compilation
        # e.g. TPU has the compilation logic in model runner, so we don't
        # need to compile the model inside.
        if self.do_not_compile or torch.compiler.is_compiling():
            return self.forward(*args, **kwargs)

        # the first compilation needs to have dynamic shapes marked
        if len(self.compiled_codes) < 1:
            sig = inspect.signature(self.__class__.forward)
            bound_args = sig.bind(self, *args, **kwargs)
            bound_args.apply_defaults()
            for k, dims in dynamic_arg_dims.items():
                arg = bound_args.arguments.get(k)
                if arg is None:
                    # `None` arguments are simply skipped
                    continue
                dims = [dims] if isinstance(dims, int) else dims
                if isinstance(arg, torch.Tensor):
                    # In case dims is specified with negative indexing
                    tensor_dims = [
                        arg.ndim + dim if dim < 0 else dim for dim in dims
                    ]
                    torch._dynamo.mark_dynamic(arg, tensor_dims)
                elif isinstance(arg, IntermediateTensors):
                    for tensor in arg.tensors.values():
                        # In case dims is specified with negative indexing.
                        # Negative dims are resolved against each
                        # tensor's own ndim; `dims` is left untouched so
                        # one tensor's rank never leaks into the next.
                        tensor_dims = [
                            tensor.ndim + dim if dim < 0 else dim
                            for dim in dims
                        ]
                        torch._dynamo.mark_dynamic(tensor, tensor_dims)
                else:
                    raise ValueError(
                        "Unsupported dynamic dimensions"
                        f" {dims} for argument {k} with type {type(arg)}.")
            # here, it is the starting point of the `torch.compile` process
            start_monitoring_torch_compile(self.vllm_config)
            logger.debug("Start compiling function %s",
                         self.original_code_object)

        # if we don't use custom dispatcher, we can directly call the
        # compiled function and let torch.compile handle the dispatching,
        # with the overhead of guard evaluation and recompilation.
        if len(self.compiled_codes) < 1 or not self.use_custom_dispatcher:
            # it seems Dynamo reuse the compilation across instances,
            # while we need to make sure the compiled code is not reused.
            # we need to control all the compilation of the model.
            torch._dynamo.eval_frame.remove_from_cache(
                self.original_code_object)

            # collect all relevant files traced by Dynamo,
            # so that the compilation cache can trigger re-compilation
            # properly when any of these files change.

            # 1. the file containing the top-level forward function
            self.vllm_config.compilation_config.traced_files.add(
                self.original_code_object.co_filename)

            # 2. every time Dynamo sees a function call, it will inline
            # the function by calling InliningInstructionTranslator.inline_call
            # we hijack this function to know all the functions called
            # during Dynamo tracing, and their corresponding files
            inline_call = InliningInstructionTranslator.inline_call

            def patched_inline_call(parent, func, args, kwargs):
                code = func.get_code()
                self.vllm_config.compilation_config.traced_files.add(
                    code.co_filename)
                return inline_call(parent, func, args, kwargs)

            # Disable the C++ compilation of symbolic shape guards. C++-fication
            # of symbolic shape guards can improve guard overhead. But, since
            # vllm skips guards anyway, setting this flag to False can improve
            # compile time.
            dynamo_config_patches = {}
            try:
                _ = torch._dynamo.config.enable_cpp_symbolic_shape_guards
                dynamo_config_patches[
                    "enable_cpp_symbolic_shape_guards"] = False
            except AttributeError:
                # Note: this config is not available in torch 2.6, we can skip
                # if the config doesn't exist
                logger.debug(
                    "enable_cpp_symbolic_shape_guards config not available")

            with patch.object(
                    InliningInstructionTranslator, "inline_call",
                    patched_inline_call), torch._dynamo.config.patch(
                        **dynamo_config_patches
                    ), maybe_use_cudagraph_partition_wrapper(
                        self.vllm_config), _torch27_patch_tensor_subclasses():
                output = self.compiled_callable(*args, **kwargs)
            return output

        # usually, capturing the model once is enough, and then we can
        # dispatch to the compiled code directly, without going through
        # the Dynamo guard mechanism.
        with self.dispatch_to_code(0):
            model_output = self.forward(*args, **kwargs)
            return model_output

    cls.__call__ = __call__
    return cls
|
||||
|
||||
|
||||
@contextlib.contextmanager
def maybe_use_cudagraph_partition_wrapper(vllm_config: VllmConfig):
    """
    Context manager to set/unset customized cudagraph partition wrappers.

    If we're using Inductor-based graph partitioning, we currently have the
    whole `fx.Graph` before Inductor lowering and the piecewise
    splitting happens after all graph passes and fusions. Here, we add
    a custom hook for Inductor to wrap each partition with our static
    graph wrapper class to maintain more control over static graph
    capture and replay.

    The hook is only installed when `cudagraph_mode` is not NONE and
    `use_inductor_graph_partition` is set; otherwise this context manager
    does nothing. An installed hook is reset to `None` on exit, including
    when the body of the `with` block raises.
    """
    from vllm.config import CUDAGraphMode

    compilation_config = vllm_config.compilation_config
    # Decide once, so that install and reset always go together even if
    # the config is modified while the context is active.
    use_partition_wrapper = (
        compilation_config.cudagraph_mode != CUDAGraphMode.NONE
        and compilation_config.use_inductor_graph_partition)

    if not use_partition_wrapper:
        yield
        return

    from torch._inductor.utils import CUDAGraphWrapperMetadata

    from vllm.compilation.cuda_graph import CUDAGraphOptions
    from vllm.platforms import current_platform

    static_graph_wrapper_class = resolve_obj_by_qualname(
        current_platform.get_static_graph_wrapper_cls())

    def customized_cudagraph_wrapper(f,
                                     metadata: CUDAGraphWrapperMetadata):
        partition_id = metadata.partition_index
        num_partitions = metadata.num_partitions
        return static_graph_wrapper_class(
            runnable=f,
            vllm_config=vllm_config,
            runtime_mode=CUDAGraphMode.PIECEWISE,
            cudagraph_options=CUDAGraphOptions(
                # log only for the first partition
                debug_log_enable=partition_id == 0,
                # gc is left enabled only for the first partition
                gc_disable=partition_id != 0,
                # only the last partition's output is weak-referenced
                weak_ref_output=partition_id == num_partitions - 1,
            ))

    torch._inductor.utils.set_customized_partition_wrappers(
        customized_cudagraph_wrapper)
    try:
        yield
    finally:
        # The wrapper is process-global Inductor state: always clear it,
        # otherwise a failure inside the `with` body would leave the hook
        # installed for unrelated compilations.
        torch._inductor.utils.set_customized_partition_wrappers(None)
|
||||
|
||||
|
||||
@contextlib.contextmanager
def _torch27_patch_tensor_subclasses():
    """
    Add support for using tensor subclasses (i.e. `BasevLLMParameter`, etc.)
    when using torch 2.7.0. This enables using weight_loader_v2 and the use
    of `BasevLLMParameters` without having to replace them with regular
    tensors before `torch.compile`-time.

    As written, the context is a no-op when the installed torch version is
    in [2.7, 2.8); for any other version the Dynamo patches are applied.
    """
    # NOTE(review): the version gate below skips the patches on torch 2.7.x,
    # while the text above says they are meant for 2.7.0 -- confirm which
    # one is intended before changing either.
    from vllm.model_executor.parameter import (BasevLLMParameter,
                                               ModelWeightParameter,
                                               RowvLLMParameter,
                                               _ColumnvLLMParameter)

    def return_false(*args, **kwargs):
        return False

    installed = version.parse(torch.__version__)
    if version.parse("2.7") <= installed < version.parse("2.8"):
        yield
        return

    traceable_subclasses = [
        BasevLLMParameter, ModelWeightParameter, _ColumnvLLMParameter,
        RowvLLMParameter
    ]
    with torch._dynamo.config.patch("traceable_tensor_subclasses",
                                    traceable_subclasses):
        with patch(
                "torch._dynamo.variables.torch.can_dispatch_torch_function",
                return_false):
            yield
|
||||
205
vllm/compilation/fix_functionalization.py
Normal file
205
vllm/compilation/fix_functionalization.py
Normal file
@@ -0,0 +1,205 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import operator
|
||||
from collections.abc import Iterable
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .fx_utils import is_func
|
||||
from .vllm_inductor_pass import VllmInductorPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class FixFunctionalizationPass(VllmInductorPass):
|
||||
"""
|
||||
This pass defunctionalizes certain nodes to avoid redundant tensor copies.
|
||||
After this pass, DCE (dead-code elimination) should never be run,
|
||||
as de-functionalized nodes may appear as dead code.
|
||||
|
||||
To add new nodes to defunctionalize, add to the if-elif chain in __call__.
|
||||
"""
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: torch.fx.Graph):
|
||||
# XPU does not support auto-functionalization yet.
|
||||
# Will enable this when switch to vllm-xpu-kernels.
|
||||
if current_platform.is_xpu():
|
||||
logger.debug("XPU platform does not support fix functionalization"
|
||||
"pass currently.")
|
||||
return
|
||||
|
||||
self.nodes_to_remove: list[torch.fx.Node] = []
|
||||
count = 0
|
||||
for node in graph.nodes:
|
||||
if not is_func(node, auto_functionalized):
|
||||
continue # Avoid deep if-elif nesting
|
||||
|
||||
kwargs = node.kwargs
|
||||
at_target = node.args[0]
|
||||
|
||||
if at_target == torch.ops._C.rotary_embedding.default:
|
||||
query = kwargs['query']
|
||||
mm_node = query.args[0].args[0]
|
||||
|
||||
# rotary_embedding is a special case: the two mutating inputs
|
||||
# are query and key, which are slices of mm_node.
|
||||
# While functionalized, results at[1] and at[2] are scattered
|
||||
# back into mm_node. After de-functionalization, we can just
|
||||
# use mm_node directly.
|
||||
for idx, user in self.getitem_users(node).items():
|
||||
for user_of_getitem in user.users:
|
||||
if is_func(user_of_getitem,
|
||||
torch.ops.aten.slice_scatter.default):
|
||||
user_of_getitem.replace_all_uses_with(mm_node)
|
||||
self._remove(user_of_getitem)
|
||||
self._remove(user)
|
||||
|
||||
self.insert_defunctionalized(graph, node)
|
||||
self._remove(node)
|
||||
|
||||
# rms_norm replacements avoid the most copies for LLaMa.
|
||||
elif at_target == torch.ops._C.fused_add_rms_norm.default:
|
||||
mutated_args = {1: 'input', 2: 'residual'}
|
||||
self.defunctionalize(graph, node, mutated_args)
|
||||
elif at_target == torch.ops._C.fused_add_rms_norm_static_fp8_quant.default: # noqa: E501
|
||||
mutated_args = {1: 'result', 2: 'residual'}
|
||||
self.defunctionalize(graph, node, mutated_args)
|
||||
elif at_target == torch.ops._C.rms_norm_dynamic_per_token_quant.default: # noqa: E501
|
||||
mutated_args = {1: 'result', 2: 'scale', 3: 'residual'}
|
||||
self.defunctionalize(graph, node, mutated_args)
|
||||
elif at_target in [
|
||||
torch.ops._C.rms_norm.default,
|
||||
torch.ops._C.rms_norm_static_fp8_quant.default,
|
||||
]:
|
||||
mutated_args = {1: 'result'}
|
||||
self.defunctionalize(graph, node, mutated_args)
|
||||
# For some reason we need to specify the args for both
|
||||
# silu_and_mul and silu_and_mul_quant. The kwargs
|
||||
# pathway gets the wrong answer.
|
||||
elif at_target == torch.ops._C.silu_and_mul.default:
|
||||
mutated_args = {1: 'result'}
|
||||
self.defunctionalize(graph,
|
||||
node,
|
||||
mutated_args,
|
||||
args=('result', 'input'))
|
||||
elif at_target == torch.ops._C.silu_and_mul_quant.default:
|
||||
mutated_args = {1: 'result'}
|
||||
self.defunctionalize(graph,
|
||||
node,
|
||||
mutated_args,
|
||||
args=('result', 'input', 'scale'))
|
||||
elif hasattr(
|
||||
torch.ops._C, "silu_and_mul_nvfp4_quant"
|
||||
) and at_target == torch.ops._C.silu_and_mul_nvfp4_quant.default:
|
||||
mutated_args = {1: 'result', 2: 'result_block_scale'}
|
||||
self.defunctionalize(graph,
|
||||
node,
|
||||
mutated_args,
|
||||
args=('result', 'result_block_scale',
|
||||
'input', 'input_global_scale'))
|
||||
else:
|
||||
continue # skip the count
|
||||
|
||||
count += 1
|
||||
|
||||
self.dump_graph(graph, "before_cleanup")
|
||||
|
||||
# Remove the nodes all at once
|
||||
count_removed = len(self.nodes_to_remove)
|
||||
for node in self.nodes_to_remove:
|
||||
graph.erase_node(node)
|
||||
|
||||
logger.debug("De-functionalized %s nodes, removed %s nodes", count,
|
||||
count_removed)
|
||||
self.nodes_to_remove.clear()
|
||||
|
||||
def _remove(self, node_or_nodes: Union[torch.fx.Node,
|
||||
Iterable[torch.fx.Node]]):
|
||||
"""
|
||||
Stage a node (or nodes) for removal at the end of the pass.
|
||||
"""
|
||||
if isinstance(node_or_nodes, torch.fx.Node):
|
||||
self.nodes_to_remove.append(node_or_nodes)
|
||||
else:
|
||||
self.nodes_to_remove.extend(node_or_nodes)
|
||||
|
||||
def defunctionalize(self,
|
||||
graph: torch.fx.Graph,
|
||||
node: torch.fx.Node,
|
||||
mutated_args: dict[int, Union[torch.fx.Node, str]],
|
||||
args: Optional[tuple[Union[torch.fx.Node, str],
|
||||
...]] = None):
|
||||
"""
|
||||
De-functionalize a node by replacing it with a call to the original.
|
||||
It also replaces the getitem users with the mutated arguments.
|
||||
See replace_users_with_mutated_args and insert_defunctionalized.
|
||||
"""
|
||||
self.replace_users_with_mutated_args(node, mutated_args)
|
||||
self.insert_defunctionalized(graph, node, args=args)
|
||||
self._remove(node)
|
||||
|
||||
def replace_users_with_mutated_args(self, node: torch.fx.Node,
|
||||
mutated_args: dict[int,
|
||||
Union[torch.fx.Node,
|
||||
str]]):
|
||||
"""
|
||||
Replace all getitem users of the auto-functionalized node with the
|
||||
mutated arguments.
|
||||
:param node: The auto-functionalized node
|
||||
:param mutated_args: The mutated arguments, indexed by getitem index.
|
||||
If the value of an arg is a string, `node.kwargs[arg]` is used.
|
||||
"""
|
||||
for idx, user in self.getitem_users(node).items():
|
||||
arg = mutated_args[idx]
|
||||
arg = node.kwargs[arg] if isinstance(arg, str) else arg
|
||||
user.replace_all_uses_with(arg)
|
||||
self._remove(user)
|
||||
|
||||
def getitem_users(self, node: torch.fx.Node) -> dict[int, torch.fx.Node]:
|
||||
"""
|
||||
Returns the operator.getitem users of the auto-functionalized node,
|
||||
indexed by the index they are getting.
|
||||
"""
|
||||
users = {}
|
||||
for user in node.users:
|
||||
if is_func(user, operator.getitem):
|
||||
idx = user.args[1]
|
||||
users[idx] = user
|
||||
return users
|
||||
|
||||
def insert_defunctionalized(
        self,
        graph: torch.fx.Graph,
        node: torch.fx.Node,
        args: Optional[tuple[Union[torch.fx.Node, str], ...]] = None):
    """
    Insert a call to the original (non-functionalized) op before node.

    By default the new call reuses node.kwargs. If one of the kwargs is
    'out', node.kwargs cannot be used and args must be provided instead.
    See https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351

    :param graph: Graph to insert the defunctionalized node into
    :param node: The auto-functionalized node to defunctionalize
    :param args: If we cannot use kwargs, specify args directly.
        If an arg is a string, `node.kwargs[arg]` is used.
    """  # noqa: E501
    assert is_func(node, auto_functionalized), \
        f"node must be auto-functionalized, is {node} instead"

    # The wrapped op is the first positional argument of the
    # auto_functionalized call.
    original_op = node.args[0]
    with graph.inserting_before(node):
        if args is None:
            graph.call_function(original_op, kwargs=node.kwargs)
            return

        # Args passed as strings refer to items in node.kwargs
        resolved_args = tuple(
            node.kwargs[item] if isinstance(item, str) else item
            for item in args)
        graph.call_function(original_op, args=resolved_args)
|
||||
383
vllm/compilation/fusion.py
Normal file
383
vllm/compilation/fusion.py
Normal file
@@ -0,0 +1,383 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Any, NamedTuple
|
||||
|
||||
import torch
|
||||
import torch._inductor.pattern_matcher as pm
|
||||
from torch import fx
|
||||
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||
from torch._ops import OpOverload
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape, QuantKey, ScaleDesc, kFp8DynamicTensorSym, kFp8DynamicTokenSym,
|
||||
kFp8StaticTensorSym, kNvfp4Quant, kStaticTensorScale)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .inductor_pass import enable_fake_mode
|
||||
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
|
||||
# Module logger.
logger = init_logger(__name__)

# fp8 dtype as reported by the current platform.
FP8_DTYPE = current_platform.fp8_dtype()
# Storage dtype used for fp4 outputs.
# NOTE(review): presumably two fp4 values are packed per byte - confirm.
FP4_DTYPE = torch.uint8
|
||||
|
||||
|
||||
def empty_bf16(*args, **kwargs):
    """
    Return an uninitialized bfloat16 tensor.

    Positional and keyword arguments are forwarded to `torch.empty`.
    The tensor is placed on "cuda" unless a `device` keyword is given;
    previously passing `device` raised a duplicate-keyword TypeError.
    The dtype is fixed and cannot be overridden.
    """
    kwargs.setdefault("device", "cuda")
    return torch.empty(*args, dtype=torch.bfloat16, **kwargs)
|
||||
|
||||
|
||||
def empty_fp32(*args, **kwargs):
    """
    Return an uninitialized float32 tensor.

    Positional and keyword arguments are forwarded to `torch.empty`.
    The tensor is placed on "cuda" unless a `device` keyword is given;
    previously passing `device` raised a duplicate-keyword TypeError.
    The dtype is fixed and cannot be overridden.
    """
    kwargs.setdefault("device", "cuda")
    return torch.empty(*args, dtype=torch.float32, **kwargs)
|
||||
|
||||
|
||||
def empty_i32(*args, **kwargs):
    """
    Return an uninitialized int32 tensor.

    Positional and keyword arguments are forwarded to `torch.empty`.
    The tensor is placed on "cuda" unless a `device` keyword is given;
    previously passing `device` raised a duplicate-keyword TypeError.
    The dtype is fixed and cannot be overridden.
    """
    kwargs.setdefault("device", "cuda")
    return torch.empty(*args, dtype=torch.int32, **kwargs)
|
||||
|
||||
|
||||
# Unfused norm ops that the patterns in this file search for.
RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default

# Stand-alone quantization op for each supported quantization scheme.
QUANT_OPS: dict[QuantKey, OpOverload] = {
    kFp8StaticTensorSym:
    torch.ops._C.static_scaled_fp8_quant.default,  # noqa: E501
    kFp8DynamicTensorSym:
    torch.ops._C.dynamic_scaled_fp8_quant.default,  # noqa: E501
    kFp8DynamicTokenSym:
    torch.ops._C.dynamic_per_token_scaled_fp8_quant.default,  # noqa: E501
}
# The nvfp4 entry is added only on CUDA builds that expose the op.
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
    QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default
|
||||
|
||||
|
||||
class FusedRMSQuantKey(NamedTuple):
    """
    Key that identifies one flavor of RMSNorm + quant fusion.

    quant: type of quantization
    fused_add: does the op also perform the residual add
    """
    quant: QuantKey
    fused_add: bool

    def __str__(self):
        # Renders "... with residual" or "... without residual".
        residual_word = "with" if self.fused_add else "without"
        return f"FusedQuantKey({self.quant}, {residual_word} residual)"
|
||||
|
||||
|
||||
# Fused rms_norm (+ optional residual add) + quant op for every supported
# (quant scheme, fused_add) combination. Both dynamic per-token entries
# map to the same op.
FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = {
    FusedRMSQuantKey(kFp8StaticTensorSym, False):
    torch.ops._C.rms_norm_static_fp8_quant.default,  # noqa: E501
    FusedRMSQuantKey(kFp8StaticTensorSym, True):
    torch.ops._C.fused_add_rms_norm_static_fp8_quant.default,  # noqa: E501
    FusedRMSQuantKey(kFp8DynamicTokenSym, False):
    torch.ops._C.rms_norm_dynamic_per_token_quant.default,  # noqa: E501
    FusedRMSQuantKey(kFp8DynamicTokenSym, True):
    torch.ops._C.rms_norm_dynamic_per_token_quant.default,  # noqa: E501
}
|
||||
|
||||
|
||||
class RMSNormQuantPattern:
    """
    Shared state for the RMSNorm + quant fusion patterns.

    Resolves the stand-alone quant op (QUANT_OPS) and the fused op
    (FUSED_OPS) for the given key and stores them, together with epsilon
    and the quantized dtype, on the instance.
    """

    def __init__(self, epsilon: float, key: FusedRMSQuantKey):
        quant_key = key.quant

        self.epsilon = epsilon
        self.quant_dtype = quant_key.dtype

        # Fail early if either op is unavailable for this key.
        assert quant_key in QUANT_OPS, (
            f"unsupported quantization scheme {key.quant}")
        self.QUANT_OP = QUANT_OPS[quant_key]

        assert key in FUSED_OPS, (
            f"unsupported fused rmsnorm+quant op for {key}")
        self.FUSED_OP = FUSED_OPS[key]
|
||||
|
||||
|
||||
class RMSNormStaticQuantPattern(RMSNormQuantPattern):
    """
    rms_norm followed by a static-scale quant op, replaced by the fused op.
    """

    def __init__(self,
                 epsilon: float,
                 quant_dtype: torch.dtype,
                 symmetric=True):
        quant_key = QuantKey(dtype=quant_dtype,
                             scale=kStaticTensorScale,
                             symmetric=symmetric)
        super().__init__(epsilon,
                         FusedRMSQuantKey(quant=quant_key, fused_add=False))

    def register(self, pm_pass: PatternMatcherPass):
        """Register the pattern/replacement pair on `pm_pass`."""

        # Cannot use methods, as the self argument affects tracing
        def pattern(result: torch.Tensor, result_rms: torch.Tensor,
                    input: torch.Tensor, weight: torch.Tensor,
                    scale: torch.Tensor):
            norm = auto_functionalized(RMS_OP,
                                       result=result_rms,
                                       input=input,
                                       weight=weight,
                                       epsilon=self.epsilon)
            quant = auto_functionalized(self.QUANT_OP,
                                        result=result,
                                        input=norm[1],
                                        scale=scale)

            # result
            return quant[1]

        def replacement(result: torch.Tensor, result_rms: torch.Tensor,
                        input: torch.Tensor, weight: torch.Tensor,
                        scale: torch.Tensor):
            fused = auto_functionalized(self.FUSED_OP,
                                        result=result,
                                        input=input,
                                        weight=weight,
                                        scale=scale,
                                        epsilon=self.epsilon)

            # result
            return fused[1]

        # Example tensors used to trace both functions.
        example_inputs = [
            torch.empty(5, 4, device="cuda", dtype=self.quant_dtype),  # result
            empty_bf16(5, 4),  # result_rms
            empty_bf16(5, 4),  # input
            empty_bf16(1, 5),  # weight
            empty_fp32(1, 1)  # scale
        ]

        pm.register_replacement(pattern, replacement, example_inputs,
                                pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
    """
    fused_add_rms_norm followed by a static-scale quant op, replaced by
    the fused op.
    """

    def __init__(self,
                 epsilon: float,
                 quant_dtype: torch.dtype,
                 symmetric=True):
        quant_key = QuantKey(dtype=quant_dtype,
                             scale=kStaticTensorScale,
                             symmetric=symmetric)
        super().__init__(epsilon,
                         FusedRMSQuantKey(quant=quant_key, fused_add=True))

    def register(self, pm_pass: PatternMatcherPass):
        """Register the pattern/replacement pair on `pm_pass`."""

        def pattern(result: torch.Tensor, input: torch.Tensor,
                    residual: torch.Tensor, weight: torch.Tensor,
                    scale: torch.Tensor):
            norm = auto_functionalized(RMS_ADD_OP,
                                       input=input,
                                       residual=residual,
                                       weight=weight,
                                       epsilon=self.epsilon)
            quant = auto_functionalized(self.QUANT_OP,
                                        result=result,
                                        input=norm[1],
                                        scale=scale)

            # result, residual
            return quant[1], norm[2]

        def replacement(result: torch.Tensor, input: torch.Tensor,
                        residual: torch.Tensor, weight: torch.Tensor,
                        scale: torch.Tensor):
            fused = auto_functionalized(self.FUSED_OP,
                                        result=result,
                                        input=input,
                                        residual=residual,
                                        weight=weight,
                                        scale=scale,
                                        epsilon=self.epsilon)

            # result, residual
            return fused[1], fused[2]

        # Example tensors used to trace both functions.
        example_inputs = [
            torch.empty(5, 4, device="cuda", dtype=self.quant_dtype),  # result
            empty_bf16(5, 4),  # input
            empty_bf16(5, 4),  # residual
            empty_bf16(1, 5),  # weight
            empty_fp32(1, 1)  # scale
        ]

        pm.register_replacement(pattern, replacement, example_inputs,
                                pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
    """
    rms_norm followed by a dynamic-scale quant op, replaced by the fused
    op (called with `residual=None`).
    """

    def __init__(self,
                 epsilon: float,
                 quant_dtype: torch.dtype,
                 group_shape: GroupShape = GroupShape.PER_TOKEN,
                 symmetric=True):
        scale_desc = ScaleDesc(torch.float32, False, group_shape)
        quant_key = QuantKey(dtype=quant_dtype,
                             scale=scale_desc,
                             symmetric=symmetric)
        super().__init__(epsilon,
                         FusedRMSQuantKey(quant=quant_key, fused_add=False))

    def register(self, pm_pass: PatternMatcherPass):
        """Register the pattern/replacement pair on `pm_pass`."""

        def pattern(result: torch.Tensor, result_rms: torch.Tensor,
                    input: torch.Tensor, weight: torch.Tensor,
                    scale: torch.Tensor):
            norm = auto_functionalized(RMS_OP,
                                       result=result_rms,
                                       input=input,
                                       weight=weight,
                                       epsilon=self.epsilon)
            quant = auto_functionalized(self.QUANT_OP,
                                        result=result,
                                        input=norm[1],
                                        scale=scale,
                                        scale_ub=None)

            # result, scale
            return quant[1], quant[2]

        def replacement(result: torch.Tensor, result_rms: torch.Tensor,
                        input: torch.Tensor, weight: torch.Tensor,
                        scale: torch.Tensor):
            fused = auto_functionalized(self.FUSED_OP,
                                        result=result,
                                        input=input,
                                        weight=weight,
                                        scale=scale,
                                        epsilon=self.epsilon,
                                        scale_ub=None,
                                        residual=None)

            # result, scale
            return fused[1], fused[2]

        # Example tensors used to trace both functions.
        example_inputs = [
            torch.empty(5, 4, device="cuda", dtype=self.quant_dtype),  # result
            empty_bf16(5, 4),  # result_rms
            empty_bf16(5, 4),  # input
            empty_bf16(1, 5),  # weight
            empty_fp32(1, 1)  # scale
        ]

        pm.register_replacement(pattern, replacement, example_inputs,
                                pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
    """
    fused_add_rms_norm followed by a dynamic-scale quant op, replaced by
    the fused op (called with the residual).
    """

    def __init__(self,
                 epsilon: float,
                 quant_dtype: torch.dtype,
                 group_shape: GroupShape = GroupShape.PER_TOKEN,
                 symmetric=True):
        scale_desc = ScaleDesc(torch.float32, False, group_shape)
        quant_key = QuantKey(dtype=quant_dtype,
                             scale=scale_desc,
                             symmetric=symmetric)
        super().__init__(epsilon,
                         FusedRMSQuantKey(quant=quant_key, fused_add=True))

    def register(self, pm_pass: PatternMatcherPass):
        """Register the pattern/replacement pair on `pm_pass`."""

        def pattern(result: torch.Tensor, input: torch.Tensor,
                    residual: torch.Tensor, weight: torch.Tensor,
                    scale: torch.Tensor):
            norm = auto_functionalized(RMS_ADD_OP,
                                       input=input,
                                       residual=residual,
                                       weight=weight,
                                       epsilon=self.epsilon)
            quant = auto_functionalized(self.QUANT_OP,
                                        result=result,
                                        input=norm[1],
                                        scale=scale,
                                        scale_ub=None)

            # result, residual, scale
            return quant[1], norm[2], quant[2]

        def replacement(result: torch.Tensor, input: torch.Tensor,
                        residual: torch.Tensor, weight: torch.Tensor,
                        scale: torch.Tensor):
            fused = auto_functionalized(self.FUSED_OP,
                                        result=result,
                                        input=input,
                                        weight=weight,
                                        scale=scale,
                                        epsilon=self.epsilon,
                                        scale_ub=None,
                                        residual=residual)

            # result, residual, scale
            # (the fused op yields scale at index 2 and residual at 3)
            return fused[1], fused[3], fused[2]

        # Example tensors used to trace both functions.
        example_inputs = [
            torch.empty(5, 4, device="cuda", dtype=self.quant_dtype),  # result
            empty_bf16(5, 4),  # input
            empty_bf16(5, 4),  # residual
            empty_bf16(1, 5),  # weight
            empty_fp32(1, 1)  # scale
        ]

        pm.register_replacement(pattern, replacement, example_inputs,
                                pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
class RMSNormQuantFusionPass(VllmPatternMatcherPass):
    """
    Fuses rms_norm / fused_add_rms_norm custom ops with a following quant
    custom op into a single fused rms_norm_quant op.
    """

    @enable_fake_mode
    def __init__(self, config: VllmConfig):
        super().__init__(config)

        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="rmsnorm_quant_fusion_pass")

        # Registered for every epsilon, in this order.
        pattern_classes = (
            # Fuse rms_norm + static fp8 quant
            RMSNormStaticQuantPattern,
            # Fuse fused_add_rms_norm + static fp8 quant
            FusedAddRMSNormStaticQuantPattern,
            # Fuse rms_norm + dynamic per-token fp8 quant
            RMSNormDynamicQuantPattern,
            # Fuse fused_add_rms_norm + dynamic per-token fp8 quant
            FusedAddRMSNormDynamicQuantPattern,
        )
        for epsilon in [1e-5, 1e-6]:
            for pattern_cls in pattern_classes:
                pattern_cls(epsilon, FP8_DTYPE).register(self.patterns)

        self.dump_patterns(config, self.patterns)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph):
        """Apply all registered patterns and record the match count."""
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)

    def uuid(self) -> Any:
        """Hash the source of this pass and of every pattern class."""
        return self.hash_source(
            self,
            RMSNormQuantPattern,
            RMSNormStaticQuantPattern,
            RMSNormDynamicQuantPattern,
            FusedAddRMSNormStaticQuantPattern,
            FusedAddRMSNormDynamicQuantPattern,
        )
|
||||
295
vllm/compilation/fusion_attn.py
Normal file
295
vllm/compilation/fusion_attn.py
Normal file
@@ -0,0 +1,295 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import torch
|
||||
import torch._inductor.pattern_matcher as pm
|
||||
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||
|
||||
from vllm.attention import Attention
|
||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey, kNvfp4Quant, kStaticTensorScale)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import round_up
|
||||
|
||||
from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
|
||||
from .inductor_pass import enable_fake_mode
|
||||
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
|
||||
# Module logger.
logger = init_logger(__name__)

# fp8 dtype as reported by the current platform.
FP8_DTYPE = current_platform.fp8_dtype()
# Storage dtype used for fp4 outputs.
# NOTE(review): presumably two fp4 values are packed per byte - confirm.
FP4_DTYPE = torch.uint8

# Ops that the attention + quant patterns below are built from.
ATTN_OP = torch.ops.vllm.unified_attention_with_output.default
RESHAPE_OP = torch.ops.aten.reshape.default
|
||||
|
||||
|
||||
class AttentionQuantPattern(ABC):
    """
    Common base for the Attn+Quant fusion patterns.
    Not meant to be instantiated directly.
    """

    def __init__(
        self,
        layer: Attention,
        quant_key: QuantKey,
        dtype: torch.dtype,
    ):
        # Attention layer details that the traced patterns close over.
        self.layer = layer
        self.layer_name = layer.layer_name
        self.num_heads = layer.num_heads
        self.head_size = layer.head_size

        self.quant_key = quant_key
        self.quant_dtype = quant_key.dtype
        self.dtype = dtype

        assert self.quant_key in QUANT_OPS, (
            f"unsupported quantization scheme {self.quant_key}")
        self.QUANT_OP = QUANT_OPS[self.quant_key]

    def empty(self, *args, **kwargs):
        """torch.empty defaulting to self.dtype on "cuda"; kwargs win."""
        options = {'dtype': self.dtype, 'device': "cuda"}
        options.update(kwargs)
        return torch.empty(*args, **options)

    def empty_quant(self, *args, **kwargs):
        """torch.empty defaulting to self.quant_dtype on "cuda"."""
        options = {'dtype': self.quant_dtype, 'device': "cuda"}
        options.update(kwargs)
        return torch.empty(*args, **options)

    @staticmethod
    def wrap_trace_fn(process_fx, trace_fn):
        """Return a trace function that post-processes trace_fn's result."""

        def wrapped(*args, **kwargs):
            traced = trace_fn(*args, **kwargs)
            return process_fx(traced)

        return wrapped

    @staticmethod
    def fx_view_to_reshape(gm: torch.fx.GraphModule):
        """Run inductor's view_to_reshape on gm and return gm."""
        from torch._inductor.fx_passes.post_grad import view_to_reshape
        view_to_reshape(gm)
        return gm

    def register_if_supported(self, pm_pass: PatternMatcherPass):
        """Register only if the layer's impl supports this quant key."""
        if not self.layer.impl.fused_output_quant_supported(self.quant_key):
            return
        self._register(pm_pass)

    @abstractmethod
    def _register(self, pm_pass: PatternMatcherPass):
        """Register the concrete pattern/replacement pair on pm_pass."""
        raise NotImplementedError
|
||||
|
||||
|
||||
class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
    """
    Fusion for Attention+Fp8StaticQuant.

    Registered only when the attention implementation returns True from
    `fused_output_quant_supported()`. On a match, the Fp8StaticQuant op
    is dropped from the graph and its scale is handed to the Attention op
    through the `output_scale` argument.
    """

    def __init__(
        self,
        layer: Attention,
        dtype: torch.dtype,
        symmetric: bool = True,
    ):
        super().__init__(
            layer,
            QuantKey(dtype=FP8_DTYPE,
                     scale=kStaticTensorScale,
                     symmetric=symmetric),
            dtype,
        )

    def _register(self, pm_pass: PatternMatcherPass):

        def pattern(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                    output_attn: torch.Tensor, output_quant: torch.Tensor,
                    scale: torch.Tensor):
            attn = auto_functionalized(ATTN_OP,
                                       query=q,
                                       key=k,
                                       value=v,
                                       output=output_attn,
                                       layer_name=self.layer_name,
                                       output_scale=None,
                                       output_block_scale=None)
            attn_2d = RESHAPE_OP(
                attn[1], [q.shape[0], self.num_heads * self.head_size])
            quant = auto_functionalized(self.QUANT_OP,
                                        result=output_quant,
                                        input=attn_2d,
                                        scale=scale)
            return quant[1]

        def replacement(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                        output_attn: torch.Tensor, output_quant: torch.Tensor,
                        scale: torch.Tensor):
            # attn output in quant_dtype
            quant_output_attn = torch.ops.aten.full.default(
                [q.shape[0], self.num_heads, self.head_size],
                0.0,
                dtype=self.quant_dtype,
                device=q.device)
            attn = auto_functionalized(ATTN_OP,
                                       query=q,
                                       key=k,
                                       value=v,
                                       output=quant_output_attn,
                                       layer_name=self.layer_name,
                                       output_scale=scale,
                                       output_block_scale=None)
            return RESHAPE_OP(attn[1], [-1, self.num_heads * self.head_size])

        # Example tensors used to trace both functions.
        head_shape = (5, self.num_heads, self.head_size)
        example_inputs = [
            self.empty(*head_shape, dtype=self.dtype),  # q
            self.empty(*head_shape, dtype=self.dtype),  # k
            self.empty(*head_shape, dtype=self.dtype),  # v
            self.empty(*head_shape, dtype=self.dtype),  # attn_output
            self.empty_quant(5,
                             self.num_heads * self.head_size),  # quant_output
            empty_fp32(1, 1)  # scale
        ]

        # Trace with fwd_only, then convert views to reshapes.
        trace_fn = AttentionQuantPattern.wrap_trace_fn(
            AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only)
        pm.register_replacement(pattern, replacement, example_inputs,
                                trace_fn, pm_pass)
|
||||
|
||||
|
||||
class AttentionNvfp4QuantPattern(AttentionQuantPattern):
    """
    Fusion for Attention+Nvfp4Quant.

    Registered only when the attention implementation returns True from
    `fused_output_quant_supported()`. On a match, the Nvfp4Quant op is
    dropped from the graph; its input scale is handed to the Attention op
    as `output_scale` and its block-scale buffer as `output_block_scale`.
    """

    def __init__(self, layer: Attention, dtype: torch.dtype):
        super().__init__(layer, kNvfp4Quant, dtype)

    def _register(self, pm_pass: PatternMatcherPass):

        def pattern(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                    output_attn: torch.Tensor, output_quant: torch.Tensor,
                    output_scale: torch.Tensor, input_scale: torch.Tensor):
            attn = auto_functionalized(ATTN_OP,
                                       query=q,
                                       key=k,
                                       value=v,
                                       output=output_attn,
                                       layer_name=self.layer_name,
                                       output_scale=None,
                                       output_block_scale=None)
            attn_2d = RESHAPE_OP(
                attn[1], [q.shape[0], self.num_heads * self.head_size])
            quant = auto_functionalized(self.QUANT_OP,
                                        output=output_quant,
                                        input=attn_2d,
                                        output_scale=output_scale,
                                        input_scale=input_scale)
            block_scale = torch.ops.aten.view.dtype(quant[2], FP8_DTYPE)
            return quant[1], block_scale

        def replacement(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                        output_attn: torch.Tensor, output_quant: torch.Tensor,
                        output_scale: torch.Tensor, input_scale: torch.Tensor):
            # attention output in quant_dtype
            quant_output_attn = torch.ops.aten.full.default(
                [q.shape[0], self.num_heads, self.head_size // 2],
                0.0,
                dtype=self.quant_dtype,
                device=q.device)
            # attention output block scale
            block_scale = torch.ops.aten.view.dtype(output_scale, FP8_DTYPE)
            attn = auto_functionalized(ATTN_OP,
                                       query=q,
                                       key=k,
                                       value=v,
                                       output=quant_output_attn,
                                       layer_name=self.layer_name,
                                       output_scale=input_scale,
                                       output_block_scale=block_scale)
            flat_output = RESHAPE_OP(
                attn[1], [-1, self.num_heads * self.head_size // 2])
            return flat_output, attn[2]

        # Example tensors used to trace both functions.
        example_inputs = [
            empty_bf16(5, self.num_heads, self.head_size),  # q
            empty_bf16(5, self.num_heads, self.head_size),  # k
            empty_bf16(5, self.num_heads, self.head_size),  # v
            empty_bf16(5, self.num_heads, self.head_size),  # output_attn
            self.empty_quant(5, self.num_heads * self.head_size //
                             2),  # output_quant
            empty_i32(128, round_up(self.num_heads * self.head_size // 16,
                                    4)),  # output_scale
            empty_fp32(1, 1),  # input_scale
        ]

        # Trace with fwd_only, then convert views to reshapes.
        trace_fn = AttentionQuantPattern.wrap_trace_fn(
            AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only)
        pm.register_replacement(pattern, replacement, example_inputs,
                                trace_fn, pm_pass)
|
||||
|
||||
|
||||
class AttnFusionPass(VllmPatternMatcherPass):
    """
    This pass fuses post-attention quantization onto attention if supported.

    It uses the pattern matcher and matches each layer manually, as strings
    cannot be wildcarded. This also lets us check support on attention layers
    upon registration instead of during pattern matching.

    A static fp8 pattern is tried for every attention layer; an nvfp4
    pattern is tried as well on CUDA when the scaled_fp4_quant op exists.
    Each pattern is registered only if the layer's implementation reports
    support for fusing that output quantization.
    """

    @enable_fake_mode
    def __init__(self, config: VllmConfig):
        super().__init__(config)

        self.patterns = PatternMatcherPass(pass_name="attn_fusion_pass")

        model_dtype = config.model_config.dtype
        attn_layers = get_layers_from_vllm_config(config, Attention)
        for layer in attn_layers.values():
            AttentionFp8StaticQuantPattern(
                layer, model_dtype).register_if_supported(self.patterns)

            if current_platform.is_cuda() and hasattr(torch.ops._C,
                                                      "scaled_fp4_quant"):
                AttentionNvfp4QuantPattern(
                    layer, model_dtype).register_if_supported(self.patterns)

        if len(attn_layers) == 0:
            logger.warning(
                "Attention + quant fusion is enabled, but no attention layers "
                "were found in CompilationConfig.static_forward_context "
                "so no fusion patterns were registered.")

        self.dump_patterns(config, self.patterns)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: torch.fx.graph.Graph) -> None:
        """Apply all registered patterns and record the match count."""
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Fused quant onto %s attention nodes", self.matched_count)

    def uuid(self):
        """Hash the source of this pass and of every pattern class."""
        return VllmInductorPass.hash_source(
            self,
            AttentionQuantPattern,
            AttentionFp8StaticQuantPattern,
            AttentionNvfp4QuantPattern,
        )
|
||||
84
vllm/compilation/fx_utils.py
Normal file
84
vllm/compilation/fx_utils.py
Normal file
@@ -0,0 +1,84 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import operator
|
||||
from collections.abc import Iterable, Iterator
|
||||
from typing import Optional
|
||||
|
||||
from torch import fx
|
||||
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||
from torch._ops import OpOverload
|
||||
|
||||
|
||||
def is_func(node: fx.Node, target) -> bool:
    """Return whether node is a call_function node calling target."""
    if node.op != "call_function":
        return False
    return node.target == target
|
||||
|
||||
|
||||
def is_auto_func(node: fx.Node, op: OpOverload) -> bool:
    """Return whether node is an auto_functionalized call wrapping op."""
    if not is_func(node, auto_functionalized):
        return False
    # The wrapped op is the first positional argument.
    return node.args[0] == op
|
||||
|
||||
|
||||
# Returns the first node whose target is the given op (if it exists)
def find_specified_fn_maybe(nodes: Iterable[fx.Node],
                            op: OpOverload) -> Optional[fx.Node]:
    """Return the first node in nodes with target == op, else None."""
    return next((candidate for candidate in nodes if candidate.target == op),
                None)
|
||||
|
||||
|
||||
# Returns the first node whose target is the given op
def find_specified_fn(nodes: Iterable[fx.Node], op: OpOverload) -> fx.Node:
    """Like find_specified_fn_maybe, but asserts that a node was found."""
    found = find_specified_fn_maybe(nodes, op)
    assert found is not None, f"Could not find {op} in nodes {nodes}"
    return found
|
||||
|
||||
|
||||
# Returns the first auto_functionalized node with the given op (if it exists)
def find_auto_fn_maybe(nodes: Iterable[fx.Node],
                       op: OpOverload) -> Optional[fx.Node]:
    """Return the first auto_functionalized node wrapping op, else None."""
    return next(
        (candidate for candidate in nodes if is_auto_func(candidate, op)),
        None)
|
||||
|
||||
|
||||
# Returns the first auto_functionalized node with the given op
def find_auto_fn(nodes: Iterable[fx.Node], op: OpOverload) -> fx.Node:
    """Like find_auto_fn_maybe, but asserts that a node was found."""
    found = find_auto_fn_maybe(nodes, op)
    assert found is not None, f"Could not find {op} in nodes {nodes}"
    return found
|
||||
|
||||
|
||||
# Returns the getitem node that extracts the idx-th element from node
# (if it exists)
def find_getitem_maybe(node: fx.Node, idx: int) -> Optional[fx.Node]:
    """Return the first user of node that is getitem(node, idx), else None."""
    matches = (user for user in node.users
               if is_func(user, operator.getitem) and user.args[1] == idx)
    return next(matches, None)
|
||||
|
||||
|
||||
# Returns the getitem node that extracts the idx-th element from node
def find_getitem(node: fx.Node, idx: int) -> fx.Node:
    """Like find_getitem_maybe, but asserts that a getitem was found."""
    found = find_getitem_maybe(node, idx)
    assert found is not None, f"Could not find getitem {idx} in node {node}"
    return found
|
||||
|
||||
|
||||
# An auto-functionalization-aware utility for finding nodes with a specific op
|
||||
def find_op_nodes(op: OpOverload, graph: fx.Graph) -> Iterator[fx.Node]:
|
||||
if not op._schema.is_mutable:
|
||||
yield from graph.find_nodes(op="call_function", target=op)
|
||||
|
||||
for n in graph.find_nodes(op="call_function", target=auto_functionalized):
|
||||
if n.args[0] == op:
|
||||
yield n
|
||||
|
||||
|
||||
# Asserts that the node only has one user and returns it
# Even if a node has only 1 user, it might share storage with another node,
# which might need to be taken into account.
def get_only_user(node: fx.Node) -> fx.Node:
    """Return the single user of node; asserts there is exactly one."""
    users = node.users
    assert len(users) == 1
    (only_user, ) = users
    return only_user
|
||||
136
vllm/compilation/inductor_pass.py
Normal file
136
vllm/compilation/inductor_pass.py
Normal file
@@ -0,0 +1,136 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import functools
|
||||
import hashlib
|
||||
import inspect
|
||||
import json
|
||||
import types
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import fx
|
||||
from torch._subclasses.fake_tensor import (FakeTensorMode,
|
||||
unset_fake_temporarily)
|
||||
|
||||
from vllm.utils import is_torch_equal_or_newer
|
||||
|
||||
if is_torch_equal_or_newer("2.6"):
|
||||
from torch._inductor.custom_graph_pass import CustomGraphPass
|
||||
else:
|
||||
# CustomGraphPass is not present in 2.5 or lower, import our version
|
||||
from .torch25_custom_graph_pass import ( # noqa: E501
|
||||
Torch25CustomGraphPass as CustomGraphPass)
|
||||
|
||||
# Currently active PassContext; None when no pass_context() block is active.
_pass_context = None
|
||||
|
||||
|
||||
class PassContext:
    """Holds the runtime shape (an int or None) for the current pass."""

    def __init__(self, runtime_shape: Optional[int]):
        self.runtime_shape = runtime_shape
|
||||
|
||||
|
||||
def get_pass_context() -> PassContext:
    """Return the active pass context; asserts that one has been set."""
    current = _pass_context
    assert current is not None
    return current
|
||||
|
||||
|
||||
@contextmanager
def pass_context(runtime_shape: Optional[int]):
    """Install a PassContext for `runtime_shape` while the block runs.

    The previously active context (possibly None) is restored on exit,
    including when the block raises.
    """
    global _pass_context
    saved_context = _pass_context
    _pass_context = PassContext(runtime_shape)
    try:
        yield
    finally:
        _pass_context = saved_context
|
||||
|
||||
|
||||
class InductorPass(CustomGraphPass):
|
||||
"""
|
||||
A custom graph pass that uses a hash of its source as the UUID.
|
||||
This is defined as a convenience and should work in most cases.
|
||||
"""
|
||||
|
||||
def uuid(self) -> Any:
|
||||
"""
|
||||
Provide a unique identifier for the pass, used in Inductor code cache.
|
||||
This should depend on the pass implementation, so that changes to the
|
||||
pass result in recompilation.
|
||||
By default, the object source is hashed.
|
||||
"""
|
||||
return InductorPass.hash_source(self)
|
||||
|
||||
@staticmethod
|
||||
def hash_source(*srcs: Union[str, Any]):
|
||||
"""
|
||||
Utility method to hash the sources of functions or objects.
|
||||
:param srcs: strings or objects to add to the hash.
|
||||
Objects and functions have their source inspected.
|
||||
:return:
|
||||
"""
|
||||
hasher = hashlib.sha256()
|
||||
for src in srcs:
|
||||
if isinstance(src, str):
|
||||
src_str = src
|
||||
elif isinstance(src, (types.FunctionType, type)):
|
||||
src_str = inspect.getsource(src)
|
||||
else:
|
||||
# object instance
|
||||
src_str = inspect.getsource(src.__class__)
|
||||
hasher.update(src_str.encode("utf-8"))
|
||||
return hasher.hexdigest()
|
||||
|
||||
@staticmethod
|
||||
def hash_dict(dict_: dict[Any, Any]):
|
||||
"""
|
||||
Utility method to hash a dictionary, can alternatively be used for uuid.
|
||||
:return: A sha256 hash of the json rep of the dictionary.
|
||||
"""
|
||||
encoded = json.dumps(dict_, sort_keys=True).encode("utf-8")
|
||||
return hashlib.sha256(encoded).hexdigest()
|
||||
|
||||
def is_applicable_for_shape(self, shape: Optional[int]):
|
||||
return True
|
||||
|
||||
|
||||
class CallableInductorPass(InductorPass):
    """
    Wraps a plain callable as a pass and supplies the UUID for it
    automatically.
    """

    def __init__(self,
                 callable: Callable[[fx.Graph], None],
                 uuid: Optional[Any] = None):
        self.callable = callable
        # Without an explicit uuid, fall back to hashing the callable.
        if uuid is None:
            uuid = self.hash_source(callable)
        self._uuid = uuid

    def __call__(self, graph: torch.fx.Graph):
        """Run the wrapped callable on graph."""
        self.callable(graph)

    def uuid(self) -> Any:
        """Return the uuid fixed at construction time."""
        return self._uuid
|
||||
|
||||
|
||||
def enable_fake_mode(fn: Callable[..., Any]) -> Callable[..., Any]:
|
||||
"""
|
||||
Applies a FakeTensorMode context. This is useful when you don't want to
|
||||
create or run things with real tensors.
|
||||
"""
|
||||
|
||||
@functools.wraps(fn)
|
||||
def fn_new(*args, **kwargs) -> Any:
|
||||
with torch._guards.tracing(
|
||||
None), unset_fake_temporarily(), FakeTensorMode():
|
||||
result = fn(*args, **kwargs)
|
||||
|
||||
return result
|
||||
|
||||
return fn_new
|
||||
57
vllm/compilation/monitor.py
Normal file
57
vllm/compilation/monitor.py
Normal file
@@ -0,0 +1,57 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import os
|
||||
import time
|
||||
|
||||
from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
context_manager = None
|
||||
torch_compile_start_time: float = 0.0
|
||||
|
||||
|
||||
def start_monitoring_torch_compile(vllm_config: VllmConfig):
    """
    Record the compile start time and optionally start a depyf debug dump.

    The module-level torch_compile_start_time is always refreshed. When the
    compilation level is PIECEWISE and a debug_dump_path is configured, a
    depyf.prepare_debug context for "<debug_dump_path>/rank_<rank>" is
    entered and kept in the module-level context_manager.
    """
    global torch_compile_start_time, context_manager
    torch_compile_start_time = time.time()

    compilation_config: CompilationConfig = vllm_config.compilation_config
    is_piecewise = compilation_config.level == CompilationLevel.PIECEWISE
    if not (is_piecewise and compilation_config.debug_dump_path):
        return

    import depyf
    dump_dir = os.path.join(compilation_config.debug_dump_path,
                            f"rank_{vllm_config.parallel_config.rank}")
    context_manager = depyf.prepare_debug(dump_dir)
    # entered manually; the matching __exit__ call lives in
    # end_monitoring_torch_compile
    context_manager.__enter__()
|
||||
|
||||
|
||||
def end_monitoring_torch_compile(vllm_config: VllmConfig):
    """
    Log the total compile time and close the debug dump context, if any.

    Nothing happens unless the compilation level is PIECEWISE. In that case
    compilation_config.compilation_time is logged, and a context manager
    stored in the module-level context_manager is exited and cleared.
    """
    global context_manager
    compilation_config: CompilationConfig = vllm_config.compilation_config
    if compilation_config.level != CompilationLevel.PIECEWISE:
        return

    logger.info("torch.compile takes %.2f s in total",
                compilation_config.compilation_time)
    if context_manager is None:
        return
    context_manager.__exit__(None, None, None)
    context_manager = None
|
||||
|
||||
|
||||
# Module-level switch consulted by validate_cudagraph_capturing_enabled().
cudagraph_capturing_enabled: bool = True


def validate_cudagraph_capturing_enabled():
    """
    Guard to call before any cudagraph capturing.

    Used to monitor at runtime whether a cudagraph capture is legal.

    :raises RuntimeError: if capturing is currently disabled.
    """
    if cudagraph_capturing_enabled:
        return
    raise RuntimeError("CUDA graph capturing detected at an inappropriate "
                       "time. This operation is currently disabled.")


def set_cudagraph_capturing_enabled(enabled: bool):
    """Set the module-level flag that allows or forbids cudagraph capturing."""
    global cudagraph_capturing_enabled
    cudagraph_capturing_enabled = enabled
|
||||
158
vllm/compilation/noop_elimination.py
Normal file
158
vllm/compilation/noop_elimination.py
Normal file
@@ -0,0 +1,158 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Iterable
|
||||
from typing import Union
|
||||
|
||||
import torch.fx
|
||||
from torch import SymInt
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .fx_utils import is_func
|
||||
from .vllm_inductor_pass import VllmInductorPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class NoOpEliminationPass(VllmInductorPass):
    """
    This is an inductor pass that removes redundant reshape/slice operations.
    It is required for RMSNorm-quant fusion to work properly.
    That's because apply_fp8_linear adds a reshape, which is redundant
    in the 2D-case. Additionally, torch internal no-op elimination pass does
    not handle certain slice variants.

    Cases handled:
      1. A chain of reshapes is equivalent to the last reshape called on the
         base tensor (input of the first reshape).
      2. A reshape that produces the shape of the input is redundant
      3. A slice that produces the shape of the input is redundant

    Example graph 1:
    mul_1: "f16[s0, 4096]" = ...
    view_1: "f16[s0, 128, 32]" = torch.reshape(mul_1, [-1, 128, 32])
    view_2: "f16[s0, 4096]" = torch.reshape(view_1, [-1, 4096])
    view_3: "f16[s0, 128, 32]" = torch.reshape(view_2, [-1, 128, 32])

    Can be replaced with:
    mul_1: "f16[s0, 4096]" = ...
    view_3: "f16[s0, 128, 32]" = ...

    Example graph 2:
    getitem_1: "f16[s0, 4096]" = ...
    view_1: "f16[s0, 4096]" = torch.reshape(getitem_1, [-1, 4096])
    at = auto_functionalized(static_scaled_fp8_quant, input = view_1, ...)
    out: "f8e4m3fn[s0, 4096]" = at[1]

    Can be replaced with:
    getitem_1: "f16[s0, 4096]" = ...
    at = auto_functionalized(static_scaled_fp8_quant, input = getitem_1, ...)
    out: "f8e4m3fn[s0, 4096]" = at[1]

    Example graph 3:
    arg0: "s0" = SymInt(s0)
    scaled_mm: "f16[s0, 4096]" = ...
    slice_1: "f16[s0, 4096]" = torch.slice(scaled_mm, -1, 0, arg0)
    at = auto_functionalized(fused_add_rms_norm, input = slice_1, ...)
    out: "f16[s0, 4096]" = torch.slice_scatter(scaled_mm, at[1], 0, 0, arg0)

    Can be replaced with:
    arg0: "s0" = SymInt(s0)
    scaled_mm: "f16[s0, 4096]" = ...
    at = auto_functionalized(fused_add_rms_norm, input = scaled_mm, ...)
    out: "f16[s0, 4096]" = at[1]
    """

    @VllmInductorPass.time_and_log
    def __call__(self, graph: torch.fx.Graph):
        """
        Remove no-op reshape, slice and slice_scatter nodes from the graph.

        :param graph: the fx graph, modified in place.
        """
        count = 0
        # Remove no-op reshapes/views:
        for node in graph.nodes:
            if is_func(node, torch.ops.aten.reshape.default):
                # Case 1: rewrite reshape chains to reshapes on the base tensor
                inp = node.args[0]
                # If the input is a reshape, rebind to that node
                if is_func(inp, torch.ops.aten.reshape.default):
                    # The new input is guaranteed not to be a reshape,
                    # because we process nodes in order
                    node.update_arg(0, inp.args[0])
                    if len(inp.users) == 0:
                        graph.erase_node(inp)
                        count += 1

                # Case 2: remove this reshape if it produces the original shape
                inp, shape = node.args[:2]
                input_shape = inp.meta["val"].shape
                if len(shape) != len(input_shape):
                    # Reshape changing rank, skip
                    continue

                if shape.count(-1) > 1:
                    # Invalid reshape args, skip
                    continue

                if self.reshape_all_dims_equivalent(shape, input_shape):
                    node.replace_all_uses_with(inp)
                    graph.erase_node(node)
                    count += 1

            elif is_func(node, torch.ops.aten.slice.Tensor):
                # python slicing semantics are different from reshape
                # Don't treat -1 as inferred dimension, compare shapes instead.
                # Only the tensor argument is read: dim/start/end have schema
                # defaults and may be absent from node.args.
                inp = node.args[0]
                input_shape = inp.meta["val"].shape
                output_shape = node.meta["val"].shape

                if output_shape == input_shape:
                    node.replace_all_uses_with(inp)
                    graph.erase_node(node)
                    count += 1

            elif is_func(node, torch.ops.aten.slice_scatter.default):
                # As above, only the two tensor arguments are required.
                base, view = node.args[:2]
                base_shape = base.meta["val"].shape
                view_shape = view.meta["val"].shape

                if base_shape == view_shape:
                    # the scattered view covers the whole base tensor
                    node.replace_all_uses_with(view)
                    graph.erase_node(node)
                    count += 1

        logger.debug("Removed %s no-op reshapes and slices", count)

    # ---------------------- Reshape helpers ----------------------
    def reshape_dims_equivalent(self, dim: Union[int, torch.fx.Node],
                                i_dim: Union[int, SymInt]) -> bool:
        """
        This function checks if two dimensions are equivalent.
        :param dim: The dimension arg to reshape/slice
        :param i_dim: The corresponding dimension in the input tensor
        :return: Are the dimensions equivalent?

        There are three cases in which the dimensions are equivalent:
        1. The dimensions are equal (both integers)
        2. The reshape dimension is -1 (i.e. inferred)
        3. The dimensions both correspond to the same SymInt

        While case 2 does not guarantee the dimensions are equal,
        they are equal if all other dimensions are equal.

        In case 3, the reshape dimension is a torch.fx.Node,
        and its value is a SymInt. That value is equal to the
        input dimension.
        """
        # Case 1 and 2
        if dim == i_dim or dim == -1:
            return True
        # Case 3
        return isinstance(dim, torch.fx.Node) and dim.meta["val"] == i_dim

    def reshape_all_dims_equivalent(
        self,
        dims: Iterable[Union[int, torch.fx.Node]],
        i_dims: Iterable[Union[int, SymInt]],
    ) -> bool:
        """
        Check pairwise (via zip) that every reshape dim is equivalent to
        the corresponding input dim; see reshape_dims_equivalent.
        """
        return all(
            self.reshape_dims_equivalent(s, i_s)
            for s, i_s in zip(dims, i_dims))
|
||||
125
vllm/compilation/pass_manager.py
Normal file
125
vllm/compilation/pass_manager.py
Normal file
@@ -0,0 +1,125 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import functools
|
||||
|
||||
from torch import fx as fx
|
||||
|
||||
from vllm import envs
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import set_env_var
|
||||
|
||||
from .post_cleanup import PostCleanupPass
|
||||
from .vllm_inductor_pass import VllmInductorPass
|
||||
|
||||
if current_platform.is_cuda_alike():
|
||||
from .activation_quant_fusion import ActivationQuantFusionPass
|
||||
from .fusion import RMSNormQuantFusionPass
|
||||
from .fusion_attn import AttnFusionPass
|
||||
|
||||
if current_platform.is_cuda():
|
||||
from .collective_fusion import AllReduceFusionPass, AsyncTPPass
|
||||
|
||||
from .fix_functionalization import FixFunctionalizationPass
|
||||
from .inductor_pass import CustomGraphPass, InductorPass, get_pass_context
|
||||
from .noop_elimination import NoOpEliminationPass
|
||||
from .sequence_parallelism import SequenceParallelismPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def with_pattern_match_debug(fn):
    """
    Decorator enabling inductor pattern match debug only while fn runs.

    When envs.VLLM_PATTERN_MATCH_DEBUG is set, its value is exported as
    TORCHINDUCTOR_PATTERN_MATCH_DEBUG for the duration of the call.
    Used to avoid logging builtin Inductor pattern matching.
    """

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        debug_val = envs.VLLM_PATTERN_MATCH_DEBUG
        if debug_val is None:
            return fn(*args, **kwargs)

        # optionally check rank here
        with set_env_var("TORCHINDUCTOR_PATTERN_MATCH_DEBUG", debug_val):
            return fn(*args, **kwargs)

    return wrapper
|
||||
|
||||
|
||||
class PostGradPassManager(CustomGraphPass):
    """
    The pass manager for post-grad passes.
    It handles configuration, adding custom passes, and running passes.
    It supports uuid for the Inductor code cache. That includes torch<2.6
    support using pickling (in .inductor_pass.CustomGraphPass).

    The order of the post-grad post-passes is:
    1. passes (constructor parameter)
    2. default passes (NoOpEliminationPass, fusion passes)
    3. config["post_grad_custom_post_pass"] (if it exists)
    4. post_cleanup
    5. fix_functionalization
    This way, all passes operate on a functionalized graph.
    """

    def __init__(self):
        self.passes: list[InductorPass] = []

    @with_pattern_match_debug
    def __call__(self, graph: fx.Graph):
        """
        Run every applicable pass on the graph, then post_cleanup, then
        fix_functionalization.

        VllmInductorPass.dump_prefix counts the passes that ran and is
        reset to None afterwards, also when a pass raises.
        """
        VllmInductorPass.dump_prefix = 0  # reset dump index
        try:
            shape = get_pass_context().runtime_shape
            for pass_ in self.passes:
                if pass_.is_applicable_for_shape(shape):
                    pass_(graph)
                    VllmInductorPass.dump_prefix += 1

            # post-cleanup goes before fix_functionalization
            # because it requires a functional graph
            self.post_cleanup(graph)
            VllmInductorPass.dump_prefix += 1

            # always run fix_functionalization last
            self.fix_functionalization(graph)
        finally:
            # Cleanup index, even on failure, so a stale prefix does not
            # leak into later use of VllmInductorPass
            VllmInductorPass.dump_prefix = None

    def configure(self, config: VllmConfig):
        """
        Build the default pass list from config.compilation_config.pass_config
        and create the post_cleanup and fix_functionalization passes.
        """
        self.pass_config = config.compilation_config.pass_config
        if self.pass_config.enable_noop:
            self.passes += [NoOpEliminationPass(config)]

        if self.pass_config.enable_sequence_parallelism:
            self.passes += [SequenceParallelismPass(config)]
            if self.pass_config.enable_async_tp:
                self.passes += [AsyncTPPass(config)]

        if self.pass_config.enable_fi_allreduce_fusion:
            self.passes += [AllReduceFusionPass(config)]

        if self.pass_config.enable_fusion:
            self.passes += [RMSNormQuantFusionPass(config)]
            self.passes += [ActivationQuantFusionPass(config)]

        if self.pass_config.enable_attn_fusion:
            self.passes += [AttnFusionPass(config)]

        # needs a functional graph
        self.post_cleanup = PostCleanupPass(config)
        self.fix_functionalization = FixFunctionalizationPass(config)

    def add(self, pass_: InductorPass):
        """Append a custom pass; it runs after the passes already added."""
        assert isinstance(pass_, InductorPass)
        self.passes.append(pass_)

    def uuid(self):
        """
        The PostGradPassManager is set as a custom pass in the Inductor and
        affects compilation caching. Its uuid depends on the UUIDs of all
        dependent passes and the pass config. See InductorPass for more info.
        """
        # NOTE(review): post_cleanup's uuid is not part of the state;
        # confirm that this is intentional.
        state = {"pass_config": self.pass_config.uuid(), "passes": []}
        for pass_ in self.passes:
            state["passes"].append(pass_.uuid())
        state["passes"].append(self.fix_functionalization.uuid())
        return InductorPass.hash_dict(state)
|
||||
20
vllm/compilation/post_cleanup.py
Normal file
20
vllm/compilation/post_cleanup.py
Normal file
@@ -0,0 +1,20 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from torch import fx
|
||||
|
||||
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
|
||||
|
||||
|
||||
class PostCleanupPass(VllmInductorPass):
    """
    Cleanup step that runs after the custom passes.

    The graph is topologically sorted and unused nodes are removed.
    This is needed because the pattern matcher does not guarantee producing
    a topologically sorted graph, and there may be unused nodes left around.
    """

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph) -> None:
        """Sort the graph in place, then eliminate dead code."""
        import torch._inductor.pattern_matcher as pattern_matcher

        pattern_matcher.stable_topological_sort(graph)
        graph.eliminate_dead_code()
|
||||
478
vllm/compilation/sequence_parallelism.py
Normal file
478
vllm/compilation/sequence_parallelism.py
Normal file
@@ -0,0 +1,478 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch._inductor.pattern_matcher as pm
|
||||
import torch.fx as fx
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .inductor_pass import enable_fake_mode
|
||||
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class _RMSNormAndQuantOpHelper:
|
||||
"""Base helper for RMSNorm and RMSNorm + Quantization functionalization."""
|
||||
|
||||
def __init__(self,
|
||||
epsilon: float,
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
quant_op: Optional[torch._ops.OpOverload] = None,
|
||||
**kwargs):
|
||||
self.epsilon = epsilon
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
self.quant_op = quant_op
|
||||
|
||||
def _functional_rmsnorm(self, result_buffer, input_tensor, weight_tensor):
|
||||
return torch.ops.higher_order.auto_functionalized(
|
||||
torch.ops._C.rms_norm.default,
|
||||
result=result_buffer,
|
||||
input=input_tensor,
|
||||
weight=weight_tensor,
|
||||
epsilon=self.epsilon)
|
||||
|
||||
def _functional_fused_add_rmsnorm(self, input_tensor, residual_tensor,
|
||||
weight_tensor):
|
||||
return torch.ops.higher_order.auto_functionalized(
|
||||
torch.ops._C.fused_add_rms_norm.default,
|
||||
input=input_tensor,
|
||||
residual=residual_tensor,
|
||||
weight=weight_tensor,
|
||||
epsilon=self.epsilon)
|
||||
|
||||
def _functional_rmsnorm_then_quant(self, rmsnorm_result_buffer,
|
||||
quant_result_buffer, input_tensor,
|
||||
weight_tensor, scale_tensor):
|
||||
if self.quant_op is None:
|
||||
raise RuntimeError(
|
||||
"_RMSNormAndQuantOpHelper was not initialized with a quant_op."
|
||||
)
|
||||
rmsnorm_out_tuple = self._functional_rmsnorm(rmsnorm_result_buffer,
|
||||
input_tensor,
|
||||
weight_tensor)
|
||||
quant_out_tuple = torch.ops.higher_order.auto_functionalized(
|
||||
self.quant_op,
|
||||
result=quant_result_buffer,
|
||||
input=rmsnorm_out_tuple[1],
|
||||
scale=scale_tensor)
|
||||
return quant_out_tuple
|
||||
|
||||
def _functional_fused_add_rmsnorm_then_quant(self, quant_result_buffer,
|
||||
input_tensor, residual_tensor,
|
||||
weight_tensor, scale_tensor):
|
||||
if self.quant_op is None:
|
||||
raise RuntimeError(
|
||||
"_RMSNormAndQuantOpHelper was not initialized with a quant_op."
|
||||
)
|
||||
fused_add_rmsnorm_out_tuple = self._functional_fused_add_rmsnorm(
|
||||
input_tensor, residual_tensor, weight_tensor)
|
||||
quant_out_tuple = torch.ops.higher_order.auto_functionalized(
|
||||
self.quant_op,
|
||||
result=quant_result_buffer,
|
||||
input=fused_add_rmsnorm_out_tuple[1],
|
||||
scale=scale_tensor)
|
||||
return quant_out_tuple, fused_add_rmsnorm_out_tuple[2]
|
||||
|
||||
|
||||
class _SequenceParallelPatternHelper(_RMSNormAndQuantOpHelper):
    """Adds tensor-parallel collective helpers for the sequence parallelism
    patterns."""

    def __init__(self,
                 epsilon: float,
                 dtype: torch.dtype,
                 device: str,
                 quant_op: Optional[torch._ops.OpOverload] = None,
                 **kwargs):
        super().__init__(epsilon, dtype, device, quant_op=quant_op, **kwargs)
        # TP group and world size are looked up once, at construction time
        self.tp_group = get_tp_group()
        self.tp_size = get_tensor_model_parallel_world_size()

    def _all_reduce(self, x: torch.Tensor) -> torch.Tensor:
        """All-reduce x via tensor_model_parallel_all_reduce."""
        return tensor_model_parallel_all_reduce(x)

    def _reduce_scatter(self, x: torch.Tensor) -> torch.Tensor:
        """vllm.reduce_scatter of x along dim 0 over the TP group."""
        reduce_scatter_op = torch.ops.vllm.reduce_scatter.default
        return reduce_scatter_op(x,
                                 dim=0,
                                 world_size=self.tp_size,
                                 group_name=self.tp_group.unique_name)

    def _all_gather(self, x: torch.Tensor) -> torch.Tensor:
        """vllm.all_gather of x along dim 0 over the TP group."""
        all_gather_op = torch.ops.vllm.all_gather.default
        return all_gather_op(x,
                             dim=0,
                             world_size=self.tp_size,
                             group_name=self.tp_group.unique_name)
|
||||
|
||||
|
||||
class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
    """
    Rewrites all_reduce -> rms_norm into
    reduce_scatter -> rms_norm -> all_gather.
    The second output (the all_reduce result) becomes the reduce_scatter
    result.
    """

    def get_inputs(self):
        """Example tensors (input, permute, arg3_1) used for registration."""
        tensor_opts = dict(device=self.device, dtype=self.dtype)
        return [
            torch.empty([1, 8, 4], **tensor_opts),  # input
            torch.empty([1, 8, 4], **tensor_opts),  # permute
            torch.empty([4], **tensor_opts),  # arg3_1
        ]

    def register(self, pm_pass: PatternMatcherPass):
        """Register the pattern and its replacement on pm_pass."""

        def pattern(
            input: torch.Tensor,
            permute: torch.Tensor,
            arg3_1: torch.Tensor,
        ):
            reduced = self._all_reduce(input)
            norm_out = self._functional_rmsnorm(permute, reduced, arg3_1)
            return norm_out[1], reduced

        def replacement(
            input: torch.Tensor,
            permute: torch.Tensor,
            arg3_1: torch.Tensor,
        ):
            scattered = self._reduce_scatter(input)

            # fresh result buffer shaped like the scattered tensor
            local_buffer = torch.empty_like(scattered)
            norm_out = self._functional_rmsnorm(local_buffer, scattered,
                                                arg3_1)

            gathered = self._all_gather(norm_out[1])
            return gathered, scattered

        pm.register_replacement(pattern, replacement, self.get_inputs(),
                                pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
    """
    Rewrites all_reduce -> fused_add_rms_norm into
    reduce_scatter -> fused_add_rms_norm -> all_gather.
    Elements [1] and [2] of the norm result are both pattern outputs.
    """

    def get_inputs(self):
        """Example tensors (residual, mm_1, rms_norm_weights)."""
        tensor_opts = dict(device=self.device, dtype=self.dtype)
        return [
            torch.empty([4, 4], **tensor_opts),  # residual
            torch.empty([4, 4], **tensor_opts),  # mm_1
            torch.empty([4, 4], **tensor_opts),  # rms_norm_weights
        ]

    def register(self, pm_pass: PatternMatcherPass):
        """Register the pattern and its replacement on pm_pass."""

        def pattern(
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            reduced = self._all_reduce(mm_1)
            norm_out = self._functional_fused_add_rmsnorm(
                reduced, residual, rms_norm_weights)
            return norm_out[1], norm_out[2]

        def replacement(
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            scattered = self._reduce_scatter(mm_1)
            norm_out = self._functional_fused_add_rmsnorm(
                scattered, residual, rms_norm_weights)
            gathered = self._all_gather(norm_out[1])
            return gathered, norm_out[2]

        pm.register_replacement(pattern, replacement, self.get_inputs(),
                                pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
class LastAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
    """
    Rewrites all_reduce -> fused_add_rms_norm into
    reduce_scatter -> fused_add_rms_norm -> all_gather.
    Only element [1] of the norm result is a pattern output.
    """

    def get_inputs(self):
        """Example tensors (residual, mm_1, rms_norm_weights)."""
        mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)

        residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        rms_norm_weights = torch.empty([4, 4],
                                       device=self.device,
                                       dtype=self.dtype)

        return [
            residual,
            mm_1,
            rms_norm_weights,
        ]

    def register(self, pm_pass: PatternMatcherPass):
        """Register the pattern and its replacement on pm_pass."""

        # Both functions return a single tensor, so they are annotated
        # with torch.Tensor rather than a tuple.
        def pattern(
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
        ) -> torch.Tensor:
            all_reduce = self._all_reduce(mm_1)
            rmsnorm = self._functional_fused_add_rmsnorm(
                all_reduce, residual, rms_norm_weights)
            return rmsnorm[1]

        def replacement(
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
        ) -> torch.Tensor:
            reduce_scatter = self._reduce_scatter(mm_1)
            rmsnorm = self._functional_fused_add_rmsnorm(
                reduce_scatter, residual, rms_norm_weights)
            normalized = self._all_gather(rmsnorm[1])
            return normalized

        pm.register_replacement(pattern, replacement, self.get_inputs(),
                                pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
|
||||
|
||||
class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
    """
    Rewrites all_reduce -> rms_norm -> quant into
    reduce_scatter -> rms_norm -> quant -> all_gather.
    The second output (the all_reduce result) becomes the reduce_scatter
    result.
    """

    def __init__(self, epsilon: float, dtype: torch.dtype, device: str,
                 op: torch._ops.OpOverload):
        """:param op: quantization op, forwarded to the base as quant_op."""
        super().__init__(epsilon, dtype, device, quant_op=op)

    def get_inputs(self):
        """Example tensors (input, rmsnorm_result, quant_result, weight,
        scale)."""
        shape = [1, 8, 4]
        return [
            torch.zeros(shape, device=self.device, dtype=self.dtype),
            torch.empty(shape, device=self.device, dtype=self.dtype),
            torch.empty(shape, device=self.device, dtype=FP8_DTYPE),
            torch.empty([4], device=self.device, dtype=self.dtype),
            torch.tensor(1.0, device=self.device, dtype=torch.float32),
        ]

    def register(self, pm_pass: PatternMatcherPass):
        """Register the pattern and its replacement on pm_pass."""

        def pattern(
            input: torch.Tensor,
            rmsnorm_result: torch.Tensor,
            quant_result: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ):
            reduced = self._all_reduce(input)
            quantized = self._functional_rmsnorm_then_quant(
                rmsnorm_result, quant_result, reduced, weight, scale)
            return quantized[1], reduced

        def replacement(
            input: torch.Tensor,
            rmsnorm_result: torch.Tensor,
            quant_result: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ):
            scattered = self._reduce_scatter(input)

            # new buffers shaped like the scattered tensor, keeping the
            # dtypes of the original result buffers
            norm_buffer = torch.empty_like(scattered,
                                           dtype=rmsnorm_result.dtype)
            quant_buffer = torch.empty_like(norm_buffer,
                                            dtype=quant_result.dtype)
            quantized = self._functional_rmsnorm_then_quant(
                norm_buffer, quant_buffer, scattered, weight, scale)
            gathered = self._all_gather(quantized[1])

            return gathered, scattered

        pm.register_replacement(pattern, replacement, self.get_inputs(),
                                pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
class MiddleAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
    """
    Rewrites all_reduce -> fused_add_rms_norm -> quant into
    reduce_scatter -> fused_add_rms_norm -> quant -> all_gather.
    The residual output of the norm is also a pattern output.
    """

    def __init__(self, epsilon: float, dtype: torch.dtype, device: str,
                 op: torch._ops.OpOverload):
        """:param op: quantization op, forwarded to the base as quant_op."""
        super().__init__(epsilon, dtype, device, quant_op=op)

    def get_inputs(self):
        """Example tensors (result, residual, mm_1, rms_norm_weights,
        scale)."""
        model_opts = dict(device=self.device, dtype=self.dtype)
        return [
            torch.empty([4, 4], device=self.device, dtype=FP8_DTYPE),
            torch.empty([4, 4], **model_opts),  # residual
            torch.empty([4, 4], **model_opts),  # mm_1
            torch.empty([4, 4], **model_opts),  # rms_norm_weights
            torch.empty([1, 1], device=self.device, dtype=torch.float32),
        ]

    def register(self, pm_pass: PatternMatcherPass):
        """Register the pattern and its replacement on pm_pass."""

        def pattern(
            result: torch.Tensor,
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            reduced = self._all_reduce(mm_1)
            quantized, residual_out = (
                self._functional_fused_add_rmsnorm_then_quant(
                    result, reduced, residual, rms_norm_weights, scale))
            return quantized[1], residual_out

        def replacement(
            result: torch.Tensor,
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            scattered = self._reduce_scatter(mm_1)
            # quant buffer shaped like the scattered tensor, dtype of result
            quant_buffer = torch.empty_like(scattered, dtype=result.dtype)
            quantized, residual_out = (
                self._functional_fused_add_rmsnorm_then_quant(
                    quant_buffer, scattered, residual, rms_norm_weights,
                    scale))
            gathered = self._all_gather(quantized[1])
            return gathered, residual_out

        pm.register_replacement(pattern, replacement, self.get_inputs(),
                                pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
class LastAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
    """
    Rewrites all_reduce -> fused_add_rms_norm -> quant into
    reduce_scatter -> fused_add_rms_norm -> quant -> all_gather.
    Only the quantized tensor is a pattern output.
    """

    def __init__(self, epsilon: float, dtype: torch.dtype, device: str,
                 op: torch._ops.OpOverload):
        """:param op: quantization op, forwarded to the base as quant_op."""
        super().__init__(epsilon, dtype, device, quant_op=op)

    def get_inputs(self):
        """Example tensors (result, residual, mm_1, rms_norm_weights,
        scale)."""
        mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)

        residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        rms_norm_weights = torch.empty([4, 4],
                                       device=self.device,
                                       dtype=self.dtype)
        result = torch.empty([4, 4], device=self.device, dtype=FP8_DTYPE)
        scale = torch.empty([1, 1], device=self.device, dtype=torch.float32)

        return [
            result,
            residual,
            mm_1,
            rms_norm_weights,
            scale,
        ]

    def register(self, pm_pass: PatternMatcherPass):
        """Register the pattern and its replacement on pm_pass."""

        # Both functions return a single tensor, so they are annotated
        # with torch.Tensor rather than a tuple.
        def pattern(
            result: torch.Tensor,
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
            scale: torch.Tensor,
        ) -> torch.Tensor:
            all_reduce = self._all_reduce(mm_1)
            static_fp8, _ = self._functional_fused_add_rmsnorm_then_quant(
                result, all_reduce, residual, rms_norm_weights, scale)
            return static_fp8[1]

        def replacement(
            result: torch.Tensor,
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
            scale: torch.Tensor,
        ) -> torch.Tensor:
            reduce_scatter = self._reduce_scatter(mm_1)
            # quant buffer shaped like the scattered tensor, dtype of result
            quant_result_buf = torch.empty_like(reduce_scatter,
                                                dtype=result.dtype)
            static_fp8, _ = self._functional_fused_add_rmsnorm_then_quant(
                quant_result_buf, reduce_scatter, residual, rms_norm_weights,
                scale)
            normalized = self._all_gather(static_fp8[1])
            return normalized

        pm.register_replacement(pattern, replacement, self.get_inputs(),
                                pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
class SequenceParallelismPass(VllmPatternMatcherPass):
|
||||
"""
|
||||
This pass enables sequence parallelism for models.
|
||||
It identifies patterns where an AllReduce operation is followed by
|
||||
an RMSNorm (or RMSNorm and then Quantization) operation.
|
||||
These patterns are replaced with a ReduceScatter operation, followed by
|
||||
a local RMSNorm/Quantization, and then an AllGather operation.
|
||||
|
||||
The general transformation is:
|
||||
Input -> AllReduce -> RMSNorm -> Output
|
||||
becomes
|
||||
Input -> ReduceScatter -> RMSNorm -> AllGather -> Output
|
||||
|
||||
While this pass itself does not directly yield performance improvements,
|
||||
it lays the groundwork for subsequent fusion passes, such as
|
||||
GEMM + ReduceScatter and AllGather + GEMM fusions. These fusions can
|
||||
significantly reduce communication overhead and improve overall model
|
||||
performance.
|
||||
"""
|
||||
|
||||
@enable_fake_mode
|
||||
def __init__(self, config: VllmConfig):
|
||||
super().__init__(config)
|
||||
|
||||
self.patterns: PatternMatcherPass = PatternMatcherPass(
|
||||
pass_name="sequence_parallelism_pass")
|
||||
|
||||
for epsilon in [1e-5, 1e-6]:
|
||||
# RMSNorm + Static FP8 quantization patterns
|
||||
fp8_quant_op = torch.ops._C.static_scaled_fp8_quant.default
|
||||
FirstAllReduceRMSNormStaticFP8Pattern(
|
||||
epsilon, self.model_dtype, self.device,
|
||||
fp8_quant_op).register(self.patterns)
|
||||
MiddleAllReduceRMSNormStaticFP8Pattern(
|
||||
epsilon, self.model_dtype, self.device,
|
||||
fp8_quant_op).register(self.patterns)
|
||||
LastAllReduceRMSNormStaticFP8Pattern(
|
||||
epsilon, self.model_dtype, self.device,
|
||||
fp8_quant_op).register(self.patterns)
|
||||
|
||||
# Normal RMSNorm patterns
|
||||
FirstAllReduceRMSNormPattern(epsilon, self.model_dtype,
|
||||
self.device).register(self.patterns)
|
||||
|
||||
MiddleAllReduceRMSNormPattern(epsilon, self.model_dtype,
|
||||
self.device).register(self.patterns)
|
||||
|
||||
LastAllReduceRMSNormPattern(epsilon, self.model_dtype,
|
||||
self.device).register(self.patterns)
|
||||
self.dump_patterns(config, self.patterns)
|
||||
|
||||
def is_applicable_for_shape(self, shape: Optional[int]) -> bool:
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
return shape is not None and shape % tp_size == 0
|
||||
|
||||
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph):
    """Apply the registered patterns to ``graph`` and record the match count.

    The count returned by the pattern matcher is stored on
    ``self.matched_count`` and reported at debug level.
    """
    num_matched = self.patterns.apply(graph)
    self.matched_count = num_matched
    logger.debug("Replaced %s patterns", num_matched)
|
||||
42
vllm/compilation/torch25_custom_graph_pass.py
Normal file
42
vllm/compilation/torch25_custom_graph_pass.py
Normal file
@@ -0,0 +1,42 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class Torch25CustomGraphPass(ABC):  # noqa (redefinition)
    """
    Stand-in for torch==2.6's CustomGraphPass when running on torch<2.6.

    It follows the 2.6 interface and additionally supports pickling, since
    before 2.6 the inductor code cache pickles the pass to derive its cache
    key (from 2.6 onwards uuid() is used instead).

    Subclasses only need to implement uuid() and can act as if it were used
    directly.
    """

    @abstractmethod
    def __call__(self, graph: torch.fx.graph.Graph) -> None:
        """
        Run the custom pass on ``graph``.
        """

    @abstractmethod
    def uuid(self) -> Optional[Any]:
        """
        Return an ID that uniquely identifies this custom pass implementation.
        Returning None skips inductor code caching entirely.
        """

    def __getstate__(self):
        """
        Provide the pickled state, which is simply uuid().

        torch<2.6 pickles the pass instead of calling uuid(), so forwarding
        to uuid() lets subclasses implement only that method.
        """
        state = self.uuid()
        return state

    def __setstate__(self, state):
        """
        Always raise: the pickled state is only a cache key, not a real
        serialization, so an instance cannot be rebuilt from it.

        Raises:
            ValueError: unconditionally.
        """
        reason = ("Cannot unpickle CustomGraphPass because pickling"
                  " is used for cache key uuid. Use torch>=2.6 with"
                  " native uuid support for custom passes.")
        raise ValueError(reason)
|
||||
156
vllm/compilation/vllm_inductor_pass.py
Normal file
156
vllm/compilation/vllm_inductor_pass.py
Normal file
@@ -0,0 +1,156 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import functools
|
||||
import operator
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import ClassVar, Optional
|
||||
|
||||
import regex as re
|
||||
import torch
|
||||
from torch._dynamo.utils import lazy_format_graph_code
|
||||
from torch._inductor.pattern_matcher import (PatternMatcherPass,
|
||||
PatternPrettyPrinter)
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .inductor_pass import InductorPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class VllmInductorPass(InductorPass):
    """
    An inductor pass that has access to the vLLM PassConfig.

    It offers helpers for timing a pass, logging its duration, and dumping
    the graph before and after the pass runs.
    """
    dump_prefix: ClassVar[Optional[int]] = None
    """Keep track of pass index for debug dump ordering."""

    def __init__(self, config: VllmConfig):
        """Cache the pass config, model dtype, device and pass name.

        ``model_dtype`` and ``device`` are None when the corresponding
        sub-config on ``config`` is falsy.
        """
        self.pass_config = config.compilation_config.pass_config

        model_config = config.model_config
        self.model_dtype = model_config.dtype if model_config else None

        device_config = config.device_config
        self.device = device_config.device if device_config else None

        self.pass_name = self.__class__.__name__

    @staticmethod
    def time_and_log(call_fn):
        """Decorator for ``__call__`` that adds timing and graph dumps.

        The returned wrapper starts the timer, dumps the graph with stage
        "before", runs ``call_fn``, dumps the graph with stage "after" and
        then logs the elapsed time. The wrapper itself returns None.
        """

        @functools.wraps(call_fn)
        def wrapped(self: VllmInductorPass, graph: torch.fx.Graph):
            self.begin()
            self.dump_graph(graph, "before")
            call_fn(self, graph)
            self.dump_graph(graph, "after")
            self.end_and_log()

        return wrapped

    def dump_graph(self, graph: torch.fx.Graph, stage: str):
        """Pass a labelled view of ``graph`` to lazy_format_graph_code.

        The label is ``post_grad[.<dump_prefix>].<pass_name>.<stage>``; the
        prefix part is omitted while ``dump_prefix`` is None.
        """
        prefix = VllmInductorPass.dump_prefix
        index_part = f".{prefix}" if prefix is not None else ""
        label = f"post_grad{index_part}.{self.pass_name}.{stage}"
        lazy_format_graph_code(label, graph.owning_module)

    def begin(self):
        """Record the start timestamp (nanoseconds) for this pass."""
        self._start_time = time.perf_counter_ns()

    def end_and_log(self):
        """Record the end timestamp and log the duration in milliseconds."""
        self._end_time = time.perf_counter_ns()
        elapsed_ns = self._end_time - self._start_time
        duration_ms = float(elapsed_ns) / 1.0e6
        logger.debug("%s completed in %.1f ms", self.pass_name, duration_ms)
|
||||
|
||||
|
||||
class VllmPatternMatcherPass(VllmInductorPass):
    """
    A VllmInductorPass built on the Inductor pattern matcher.

    Its main purpose is the dump_patterns utility, which writes the
    Inductor pattern-matcher patterns to a file to make debugging easier.

    TODO(luka) move more utilities to this pass.
    """
    matched_count: int = 0
    """The number of matched patterns in the pass."""

    _OP_OVERLOAD_PATTERN: ClassVar[re.Pattern] = re.compile(
        r"<OpOverload\(op='([^']*)', overload='([^']*)'\)>")

    def _replace_op_overloads(self, string: str) -> str:
        """Rewrite every ``<OpOverload(op=..., overload=...)>`` repr in
        ``string`` as ``torch.ops.<op>.<overload>``."""

        def as_torch_ops_path(match) -> str:
            op_name, overload_name = match.group(1), match.group(2)
            return f"torch.ops.{op_name}.{overload_name}"

        return self._OP_OVERLOAD_PATTERN.sub(as_torch_ops_path, string)

    def dump_patterns(self, config: VllmConfig, pm_pass: PatternMatcherPass):
        """
        Write the patterns registered on ``pm_pass`` to a file, if debug
        dumping is enabled.

        Nothing happens when ``config.compilation_config.debug_dump_path``
        is falsy. Otherwise a ``rank_<rank>`` directory is created under it
        and a ``patterns.<pass_name>.<i>.py`` file is written there.

        The output is a best-effort rendering that looks like Python code
        for easier debugging and navigation. If any errors appear in the
        output, please extend this method.

        TODO(luka): use pattern object to manually produce pattern graph
        """
        dump_root = config.compilation_config.debug_dump_path
        if not dump_root:
            return

        rank = config.parallel_config.rank
        rank_dir = Path(dump_root) / f"rank_{rank}"
        rank_dir.mkdir(parents=True, exist_ok=True)

        from vllm.utils import unique_filepath
        file_path = unique_filepath(
            lambda i: rank_dir / f"patterns.{self.pass_name}.{i}.py")

        with file_path.open("w") as f:
            header = (
                f'# This file was produced by VllmPatternMatcherPass.'
                f'dump_patterns for {self.pass_name}.\n'
                f'# It does its best to produce valid-Python-looking code but'
                f' please add to dump_patterns if there are any errors.\n\n'
                f'from torch._higher_order_ops.auto_functionalize import '
                f'auto_functionalized as auto_functionalized\n'
                f'from torch._inductor.pattern_matcher import *')
            print(header, file=f)

            for node, patterns in pm_pass.patterns.items():
                # operator.getitem needs a hand-written repr.
                if node[1] == operator.getitem:
                    raw_node_repr = f"({node[0]!r}, operator.getitem)"
                else:
                    raw_node_repr = repr(node)
                node_repr = self._replace_op_overloads(raw_node_repr)

                print(f"\n\n# Patterns for op: {node_repr}", file=f)
                for i, pattern in enumerate(patterns):
                    # Reserve the auto_functionalized name ahead of time.
                    pp = PatternPrettyPrinter()
                    pp.namespace.create_name("auto_functionalized", None)

                    # Assemble the pattern as a function: header line, one
                    # assignment per memoized object, then the return.
                    out_node = pp.pretty_print(pattern.pattern)
                    body_lines = [f"def pattern_{i}():"]
                    body_lines.extend(
                        f"{pp.memoized_objs_names[key]} = "
                        f"{pp.memoized_objs_pp[key]}"
                        for key in pp.memoized_objs_names)
                    body_lines.append(f"return {out_node}")

                    # NOTE(review): continuation lines are indented by the
                    # string below; it is a single space here - confirm
                    # this is the intended indent width.
                    pattern_repr = "\n".join(body_lines).replace("\n", "\n ")

                    pattern_repr = self._replace_op_overloads(pattern_repr)
                    print(f"{pattern_repr}\n", file=f)
|
||||
|
||||
|
||||
class PrinterInductorPass(VllmInductorPass):
    """A pass that only dumps the graph, using ``name`` as the stage label."""

    def __init__(self, name: str, config: VllmConfig):
        """Initialise the base pass from ``config`` and remember ``name``."""
        super().__init__(config)
        self.name = name

    def __call__(self, graph: torch.fx.Graph):
        """Dump ``graph`` via dump_graph with the stored name as the stage."""
        stage_label = self.name
        self.dump_graph(graph, stage_label)
|
||||
136
vllm/compilation/wrapper.py
Normal file
136
vllm/compilation/wrapper.py
Normal file
@@ -0,0 +1,136 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import os
|
||||
import sys
|
||||
from abc import abstractmethod
|
||||
from contextlib import contextmanager
|
||||
from types import CodeType
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import (CompilationLevel, CUDAGraphMode,
|
||||
get_current_vllm_config)
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class TorchCompileWrapperWithCustomDispatcher:
    """
    A wrapper class for torch.compile, with a custom dispatch logic.
    Subclasses should:
    1. Implement the forward method
    2. Implement the dispatch logic in the __call__ method
        It can use `self.compiled_codes` to access the compiled bytecode,
        and `with self.dispatch_to_code(index):` to dispatch to
        the compiled code.
    3. Implement the `__init__` method to determine how to call
        `torch.compile` over the forward method.
    """

    def __init__(self,
                 compiled_callable: Optional[Callable] = None,
                 compilation_level: int = 0):
        """Set up the compiled callable and register the bytecode hook.

        Args:
            compiled_callable: A ready-made compiled callable. When None,
                ``self.forward`` is compiled with ``torch.compile`` using
                the backend produced by the current vLLM config.
            compilation_level: Compared against
                ``CompilationLevel.DYNAMO_ONCE`` to decide whether the
                custom dispatcher is used.
        """
        vllm_config = get_current_vllm_config()
        self.vllm_config = vllm_config
        if compiled_callable is None:
            # default compilation settings
            # compiling the forward method

            backend = vllm_config.compilation_config.init_backend(vllm_config)
            options = None
            if isinstance(backend, str) and backend == "inductor":
                options = get_current_vllm_config(
                ).compilation_config.inductor_compile_config

            compiled_callable = torch.compile(self.forward,
                                              fullgraph=True,
                                              backend=backend,
                                              options=options)

        self.compiled_callable = compiled_callable
        self.original_code_object = self.__class__.forward.__code__
        self.compiled_codes: list[CodeType] = []
        torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook)

        # Derived from the compilation level: subclasses can use this flag
        # to switch between the custom dispatcher and the default Dynamo
        # guard mechanism.
        self.use_custom_dispatcher: bool = \
            compilation_level >= CompilationLevel.DYNAMO_ONCE

    def __call__(self, *args, **kwargs):
        """Implement the dispatch logic here, beyond the torch.compile level.
        NOTE: this function can have additional arguments beyond the forward
         method, for directly dispatching to the compiled code.
        """
        return self.compiled_callable(*args, **kwargs)

    @abstractmethod
    def forward(self, *args, **kwargs):
        ...

    def bytecode_hook(self, old_code: CodeType, new_code: CodeType):
        """Hook to save the compiled bytecode for direct execution.

        Only bytecode compiled from this class's original ``forward`` code
        object, for a frame whose ``self`` is this instance, is recorded in
        ``self.compiled_codes``. When a debug dump path is configured the
        decompiled code is also written to ``transformed_code.py`` on a
        best-effort basis.

        Raises:
            RuntimeError: if cudagraph mode is not NONE and the compiled
                code references the name ``update``.
        """
        if old_code is not self.original_code_object:
            return
        # code borrowed from https://github.com/thuml/depyf/blob/f4ad79fadee27ea113b4c75202db1eb1a11c0dbc/depyf/explain/enable_debugging.py#L25
        # Walk up the call stack until the `_compile` frame of
        # convert_frame.py is reached, then read its local `frame`.
        frame = sys._getframe()
        while frame and frame.f_back:
            frame = frame.f_back
            code_name = frame.f_code.co_name
            file_name = frame.f_code.co_filename.split(os.path.sep)[-1]
            if code_name == "_compile" and file_name == "convert_frame.py":
                break
        frame = frame.f_locals["frame"]
        assert frame.f_code == old_code

        # The hook is registered globally; ignore compilations that belong
        # to another instance.
        if frame.f_locals["self"] is not self:
            return

        self.compiled_codes.append(new_code)
        debug_dump_dir = self.vllm_config.compilation_config.debug_dump_path
        if isinstance(debug_dump_dir, str) and debug_dump_dir != "":
            rank = self.vllm_config.parallel_config.rank
            decompiled_file = os.path.join(debug_dump_dir, f"rank_{rank}",
                                           "transformed_code.py")
            if not os.path.exists(decompiled_file):
                try:
                    # usually the decompilation will succeed for most models,
                    # as we guarantee a full-graph compilation in Dynamo.
                    # but there's no 100% guarantee, since decompilation is
                    # not a reversible process.
                    import depyf
                    src = depyf.decompile(new_code)

                    # Make sure the per-rank directory exists so the write
                    # does not fail just because nothing created it yet.
                    os.makedirs(os.path.dirname(decompiled_file),
                                exist_ok=True)
                    with open(decompiled_file, "w") as f:
                        f.write(src)

                    logger.debug("Dynamo transformed code saved to %s",
                                 decompiled_file)
                except Exception:
                    # Dumping is best-effort and must not break compilation,
                    # but leave a trace instead of failing silently.
                    logger.debug(
                        "Failed to save Dynamo transformed code to %s",
                        decompiled_file,
                        exc_info=True)

        if self.vllm_config.compilation_config.cudagraph_mode != \
            CUDAGraphMode.NONE and "update" in new_code.co_names:
            import depyf
            src = depyf.decompile(new_code)
            msg = "Assigning / modifying buffers of nn.Module during forward pass is not allowed when using cudagraph inside the compiler because it will cause silent errors. Please use eager mode or fix the code. The following code contains clues about which buffer is being modified (please search for the usage of the function `update`):\n" + src  # noqa
            raise RuntimeError(msg)

    @contextmanager
    def dispatch_to_code(self, index: int):
        """Context manager to dispatch to the compiled code.
        Why does this work? Because Dynamo guarantees that the compiled
        bytecode has exactly the same arguments, cell variables, and free
        variables as the original code. Therefore we can directly switch
        the code object in the function and call it.

        The original code object is restored when the block exits, including
        when the block raises.

        See https://dev-discuss.pytorch.org/t/what-is-the-relationship-requirement-among-original-bytecode-transformed-bytecode-and-bytecode-returned-by-hooks-in-dynamo/1693/7 for more details.
        """ # noqa
        self.__class__.forward.__code__ = self.compiled_codes[index]
        try:
            yield
        finally:
            # Always undo the swap; otherwise an exception inside the block
            # would leave the class permanently running the compiled code.
            self.__class__.forward.__code__ = self.original_code_object
|
||||
Reference in New Issue
Block a user