v1.0
This commit is contained in:
0
compilation/__init__.py
Normal file
0
compilation/__init__.py
Normal file
BIN
compilation/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
compilation/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
compilation/__pycache__/activation_quant_fusion.cpython-312.pyc
Normal file
BIN
compilation/__pycache__/activation_quant_fusion.cpython-312.pyc
Normal file
Binary file not shown.
BIN
compilation/__pycache__/backends.cpython-312.pyc
Normal file
BIN
compilation/__pycache__/backends.cpython-312.pyc
Normal file
Binary file not shown.
BIN
compilation/__pycache__/base_static_graph.cpython-312.pyc
Normal file
BIN
compilation/__pycache__/base_static_graph.cpython-312.pyc
Normal file
Binary file not shown.
BIN
compilation/__pycache__/caching.cpython-312.pyc
Normal file
BIN
compilation/__pycache__/caching.cpython-312.pyc
Normal file
Binary file not shown.
BIN
compilation/__pycache__/collective_fusion.cpython-312.pyc
Normal file
BIN
compilation/__pycache__/collective_fusion.cpython-312.pyc
Normal file
Binary file not shown.
BIN
compilation/__pycache__/compiler_interface.cpython-312.pyc
Normal file
BIN
compilation/__pycache__/compiler_interface.cpython-312.pyc
Normal file
Binary file not shown.
BIN
compilation/__pycache__/counter.cpython-312.pyc
Normal file
BIN
compilation/__pycache__/counter.cpython-312.pyc
Normal file
Binary file not shown.
BIN
compilation/__pycache__/cuda_graph.cpython-312.pyc
Normal file
BIN
compilation/__pycache__/cuda_graph.cpython-312.pyc
Normal file
Binary file not shown.
BIN
compilation/__pycache__/decorators.cpython-312.pyc
Normal file
BIN
compilation/__pycache__/decorators.cpython-312.pyc
Normal file
Binary file not shown.
BIN
compilation/__pycache__/fix_functionalization.cpython-312.pyc
Normal file
BIN
compilation/__pycache__/fix_functionalization.cpython-312.pyc
Normal file
Binary file not shown.
BIN
compilation/__pycache__/fusion.cpython-312.pyc
Normal file
BIN
compilation/__pycache__/fusion.cpython-312.pyc
Normal file
Binary file not shown.
BIN
compilation/__pycache__/fusion_attn.cpython-312.pyc
Normal file
BIN
compilation/__pycache__/fusion_attn.cpython-312.pyc
Normal file
Binary file not shown.
BIN
compilation/__pycache__/fx_utils.cpython-312.pyc
Normal file
BIN
compilation/__pycache__/fx_utils.cpython-312.pyc
Normal file
Binary file not shown.
BIN
compilation/__pycache__/inductor_pass.cpython-312.pyc
Normal file
BIN
compilation/__pycache__/inductor_pass.cpython-312.pyc
Normal file
Binary file not shown.
BIN
compilation/__pycache__/matcher_utils.cpython-312.pyc
Normal file
BIN
compilation/__pycache__/matcher_utils.cpython-312.pyc
Normal file
Binary file not shown.
BIN
compilation/__pycache__/monitor.cpython-312.pyc
Normal file
BIN
compilation/__pycache__/monitor.cpython-312.pyc
Normal file
Binary file not shown.
BIN
compilation/__pycache__/noop_elimination.cpython-312.pyc
Normal file
BIN
compilation/__pycache__/noop_elimination.cpython-312.pyc
Normal file
Binary file not shown.
BIN
compilation/__pycache__/partition_rules.cpython-312.pyc
Normal file
BIN
compilation/__pycache__/partition_rules.cpython-312.pyc
Normal file
Binary file not shown.
BIN
compilation/__pycache__/pass_manager.cpython-312.pyc
Normal file
BIN
compilation/__pycache__/pass_manager.cpython-312.pyc
Normal file
Binary file not shown.
BIN
compilation/__pycache__/piecewise_backend.cpython-312.pyc
Normal file
BIN
compilation/__pycache__/piecewise_backend.cpython-312.pyc
Normal file
Binary file not shown.
BIN
compilation/__pycache__/post_cleanup.cpython-312.pyc
Normal file
BIN
compilation/__pycache__/post_cleanup.cpython-312.pyc
Normal file
Binary file not shown.
BIN
compilation/__pycache__/qk_norm_rope_fusion.cpython-312.pyc
Normal file
BIN
compilation/__pycache__/qk_norm_rope_fusion.cpython-312.pyc
Normal file
Binary file not shown.
BIN
compilation/__pycache__/sequence_parallelism.cpython-312.pyc
Normal file
BIN
compilation/__pycache__/sequence_parallelism.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
compilation/__pycache__/vllm_inductor_pass.cpython-312.pyc
Normal file
BIN
compilation/__pycache__/vllm_inductor_pass.cpython-312.pyc
Normal file
Binary file not shown.
BIN
compilation/__pycache__/wrapper.cpython-312.pyc
Normal file
BIN
compilation/__pycache__/wrapper.cpython-312.pyc
Normal file
Binary file not shown.
209
compilation/activation_quant_fusion.py
Normal file
209
compilation/activation_quant_fusion.py
Normal file
@@ -0,0 +1,209 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import torch
|
||||
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||
from torch._inductor.pattern_matcher import (
|
||||
PatternMatcherPass,
|
||||
fwd_only,
|
||||
register_replacement,
|
||||
)
|
||||
from torch._ops import OpOverload
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
kFp8StaticTensorSym,
|
||||
kNvfp4Quant,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
|
||||
from .inductor_pass import enable_fake_mode
|
||||
from .matcher_utils import MatcherQuantFP8, MatcherSiluAndMul
|
||||
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
FP4_DTYPE = torch.uint8
|
||||
|
||||
SILU_MUL_OP = torch.ops._C.silu_and_mul.default
|
||||
|
||||
FUSED_OPS: dict[QuantKey, OpOverload] = {
|
||||
kFp8StaticTensorSym: torch.ops._C.silu_and_mul_quant.default, # noqa: E501
|
||||
}
|
||||
silu_and_mul_nvfp4_quant_supported = current_platform.is_cuda() and hasattr(
|
||||
torch.ops._C, "silu_and_mul_nvfp4_quant"
|
||||
)
|
||||
if silu_and_mul_nvfp4_quant_supported:
|
||||
FUSED_OPS[kNvfp4Quant] = torch.ops._C.silu_and_mul_nvfp4_quant.default # noqa: E501
|
||||
|
||||
|
||||
class ActivationQuantPattern(ABC):
|
||||
"""
|
||||
The base class for Activation+Quant fusions.
|
||||
Should not be used directly.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
quant_key: QuantKey,
|
||||
):
|
||||
self.quant_key = quant_key
|
||||
self.quant_dtype = quant_key.dtype
|
||||
|
||||
assert self.quant_key in QUANT_OPS, (
|
||||
f"unsupported quantization scheme {self.quant_key}"
|
||||
)
|
||||
self.QUANT_OP = QUANT_OPS[self.quant_key]
|
||||
|
||||
assert self.quant_key in FUSED_OPS, (
|
||||
f"unsupported fusion scheme {self.quant_key}"
|
||||
)
|
||||
self.FUSED_OP = FUSED_OPS[self.quant_key]
|
||||
|
||||
self.silu_and_mul_matcher = MatcherSiluAndMul()
|
||||
|
||||
def empty_quant(self, *args, **kwargs):
|
||||
kwargs = {"dtype": self.quant_dtype, "device": "cuda", **kwargs}
|
||||
return torch.empty(*args, **kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class SiluMulFp8StaticQuantPattern(ActivationQuantPattern):
|
||||
"""
|
||||
Fusion for SiluMul+Fp8StaticQuant Pattern
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(kFp8StaticTensorSym)
|
||||
self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def pattern(
|
||||
input: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
result_silu_mul = self.silu_and_mul_matcher(input)
|
||||
result_quant = self.quant_matcher(result_silu_mul, scale)
|
||||
return result_quant[0]
|
||||
|
||||
def replacement(
|
||||
input: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
d = input.shape[-1] // 2
|
||||
output_shape = input.shape[:-1] + (d,)
|
||||
result = torch.empty(
|
||||
output_shape, device=input.device, dtype=self.quant_dtype
|
||||
)
|
||||
at = auto_functionalized(
|
||||
self.FUSED_OP, result=result, input=input, scale=scale
|
||||
)
|
||||
return at[1]
|
||||
|
||||
inputs = [
|
||||
*self.silu_and_mul_matcher.inputs(), # input
|
||||
self.quant_matcher.inputs()[1], # scale
|
||||
]
|
||||
pattern(*inputs)
|
||||
|
||||
register_replacement(pattern, replacement, inputs, fwd_only, pm_pass)
|
||||
|
||||
|
||||
class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
|
||||
"""
|
||||
Fusion for SiluMul+Nvfp4Quant Pattern
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(kNvfp4Quant)
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def pattern(
|
||||
result: torch.Tensor,
|
||||
output_scale: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
result_silu_mul = self.silu_and_mul_matcher(input)
|
||||
at = auto_functionalized(
|
||||
self.QUANT_OP,
|
||||
output=result,
|
||||
input=result_silu_mul,
|
||||
output_scale=output_scale,
|
||||
input_scale=scale,
|
||||
)
|
||||
return at[1], at[2]
|
||||
|
||||
def replacement(
|
||||
result: torch.Tensor,
|
||||
output_scale: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
at = auto_functionalized(
|
||||
self.FUSED_OP,
|
||||
result=result,
|
||||
result_block_scale=output_scale,
|
||||
input=input,
|
||||
input_global_scale=scale,
|
||||
)
|
||||
return at[1], at[2]
|
||||
|
||||
inputs = [
|
||||
self.empty_quant(5, 32), # result
|
||||
empty_i32(128, 4), # output_scale
|
||||
empty_bf16(5, 64), # input
|
||||
empty_fp32(1, 1), # scale
|
||||
]
|
||||
|
||||
register_replacement(pattern, replacement, inputs, fwd_only, pm_pass)
|
||||
|
||||
|
||||
class ActivationQuantFusionPass(VllmPatternMatcherPass):
|
||||
"""
|
||||
This pass fuses a pre-defined set of custom ops into fused ops.
|
||||
It uses the torch pattern matcher to find the patterns and replace them.
|
||||
|
||||
Because patterns can only be registered once, the pass is a singleton.
|
||||
This will be addressed in a future version of PyTorch:
|
||||
https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
|
||||
"""
|
||||
|
||||
@enable_fake_mode
|
||||
def __init__(self, config: VllmConfig):
|
||||
super().__init__(config)
|
||||
|
||||
self.patterns: PatternMatcherPass = PatternMatcherPass(
|
||||
pass_name="activation_quant_fusion_pass"
|
||||
)
|
||||
|
||||
pattern_silu_mul_fp8 = SiluMulFp8StaticQuantPattern()
|
||||
pattern_silu_mul_fp8.register(self.patterns)
|
||||
|
||||
if silu_and_mul_nvfp4_quant_supported:
|
||||
pattern_silu_mul_nvfp4 = SiluMulNvfp4QuantPattern()
|
||||
pattern_silu_mul_nvfp4.register(self.patterns)
|
||||
|
||||
self.dump_patterns(config, self.patterns)
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: torch.fx.Graph):
|
||||
self.matched_count = self.patterns.apply(graph)
|
||||
logger.debug("Replaced %s patterns", self.matched_count)
|
||||
|
||||
def uuid(self):
|
||||
return VllmInductorPass.hash_source(
|
||||
self,
|
||||
ActivationQuantPattern,
|
||||
SiluMulFp8StaticQuantPattern,
|
||||
SiluMulNvfp4QuantPattern,
|
||||
)
|
||||
759
compilation/backends.py
Normal file
759
compilation/backends.py
Normal file
@@ -0,0 +1,759 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import ast
|
||||
import dataclasses
|
||||
import hashlib
|
||||
import operator
|
||||
import os
|
||||
import pprint
|
||||
import time
|
||||
from collections.abc import Callable, Sequence
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.fx as fx
|
||||
from torch._dispatch.python import enable_python_dispatcher
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.inductor_pass import pass_context
|
||||
from vllm.compilation.partition_rules import (
|
||||
inductor_partition_rule_context,
|
||||
should_split,
|
||||
)
|
||||
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.import_utils import resolve_obj_by_qualname
|
||||
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||
|
||||
from .caching import VllmSerializableFunction
|
||||
from .compiler_interface import (
|
||||
CompilerInterface,
|
||||
EagerAdaptor,
|
||||
InductorAdaptor,
|
||||
InductorStandaloneAdaptor,
|
||||
is_compile_cache_enabled,
|
||||
)
|
||||
from .counter import compilation_counter
|
||||
from .inductor_pass import InductorPass
|
||||
# from .pass_manager import PostGradPassManager
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
|
||||
if compilation_config.backend == "inductor":
|
||||
# Use standalone compile only if requested, version is new enough,
|
||||
# and the symbol actually exists in this PyTorch build.
|
||||
if (
|
||||
envs.VLLM_USE_STANDALONE_COMPILE
|
||||
and is_torch_equal_or_newer("2.8.0.dev")
|
||||
and hasattr(torch._inductor, "standalone_compile")
|
||||
):
|
||||
logger.debug("Using InductorStandaloneAdaptor")
|
||||
return InductorStandaloneAdaptor(
|
||||
compilation_config.compile_cache_save_format
|
||||
)
|
||||
else:
|
||||
logger.debug("Using InductorAdaptor")
|
||||
return InductorAdaptor()
|
||||
else:
|
||||
assert compilation_config.backend == "eager", (
|
||||
"Custom backends not supported with CompilationMode.VLLM_COMPILE"
|
||||
)
|
||||
|
||||
logger.debug("Using EagerAdaptor")
|
||||
return EagerAdaptor()
|
||||
|
||||
|
||||
class CompilerManager:
|
||||
"""
|
||||
A manager to manage the compilation process, including
|
||||
caching the compiled graph, loading the compiled graph,
|
||||
and compiling the graph.
|
||||
|
||||
The cache is a dict mapping
|
||||
`(runtime_shape, graph_index, backend_name)`
|
||||
to `any_data` returned from the compiler.
|
||||
|
||||
When serializing the cache, we save it to a Python file
|
||||
for readability. We don't use json here because json doesn't
|
||||
support int as key.
|
||||
"""
|
||||
|
||||
def __init__(self, compilation_config: CompilationConfig):
|
||||
self.cache: dict[tuple[int | None, int, str], Any] = dict()
|
||||
self.is_cache_updated = False
|
||||
self.compilation_config = compilation_config
|
||||
self.compiler = make_compiler(compilation_config)
|
||||
|
||||
def compute_hash(self, vllm_config: VllmConfig) -> str:
|
||||
return self.compiler.compute_hash(vllm_config)
|
||||
|
||||
@contextmanager
|
||||
def compile_context(self, runtime_shape: int | None = None):
|
||||
"""Provide compilation context for the duration of compilation to set
|
||||
any torch global properties we want to scope to a single Inductor
|
||||
compilation (e.g. partition rules, pass context)."""
|
||||
with pass_context(runtime_shape):
|
||||
if self.compilation_config.use_inductor_graph_partition:
|
||||
with inductor_partition_rule_context(
|
||||
self.compilation_config.splitting_ops
|
||||
):
|
||||
yield
|
||||
else:
|
||||
yield
|
||||
|
||||
def initialize_cache(
|
||||
self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
|
||||
):
|
||||
"""
|
||||
Initialize the cache directory for the compiler.
|
||||
|
||||
The organization of the cache directory is as follows:
|
||||
cache_dir=/path/to/hash_str/rank_i_j/prefix/
|
||||
inside cache_dir, there will be:
|
||||
- vllm_compile_cache.py
|
||||
- computation_graph.py
|
||||
- transformed_code.py
|
||||
|
||||
for multiple prefixes, they can share the same
|
||||
base cache dir of /path/to/hash_str/rank_i_j/ ,
|
||||
to store some common compilation artifacts.
|
||||
"""
|
||||
|
||||
self.disable_cache = disable_cache
|
||||
self.cache_dir = cache_dir
|
||||
self.cache_file_path = os.path.join(cache_dir, "vllm_compile_cache.py")
|
||||
|
||||
if not disable_cache and os.path.exists(self.cache_file_path):
|
||||
# load the cache from the file
|
||||
with open(self.cache_file_path) as f:
|
||||
# we use ast.literal_eval to parse the data
|
||||
# because it is a safe way to parse Python literals.
|
||||
# do not use eval(), it is unsafe.
|
||||
self.cache = ast.literal_eval(f.read())
|
||||
|
||||
self.compiler.initialize_cache(
|
||||
cache_dir=cache_dir, disable_cache=disable_cache, prefix=prefix
|
||||
)
|
||||
|
||||
def save_to_file(self):
|
||||
if self.disable_cache or not self.is_cache_updated:
|
||||
return
|
||||
printer = pprint.PrettyPrinter(indent=4)
|
||||
data = printer.pformat(self.cache)
|
||||
with open(self.cache_file_path, "w") as f:
|
||||
f.write(data)
|
||||
|
||||
def load(
|
||||
self,
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: list[Any],
|
||||
graph_index: int,
|
||||
runtime_shape: int | None = None,
|
||||
) -> Callable | None:
|
||||
if (runtime_shape, graph_index, self.compiler.name) not in self.cache:
|
||||
return None
|
||||
handle = self.cache[(runtime_shape, graph_index, self.compiler.name)]
|
||||
compiled_graph = self.compiler.load(
|
||||
handle, graph, example_inputs, graph_index, runtime_shape
|
||||
)
|
||||
if runtime_shape is None:
|
||||
logger.debug(
|
||||
"Directly load the %s-th graph for dynamic shape from %s via handle %s",
|
||||
graph_index,
|
||||
self.compiler.name,
|
||||
handle,
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
"Directly load the %s-th graph for shape %s from %s via handle %s",
|
||||
graph_index,
|
||||
str(runtime_shape),
|
||||
self.compiler.name,
|
||||
handle,
|
||||
)
|
||||
return compiled_graph
|
||||
|
||||
def compile(
|
||||
self,
|
||||
graph: fx.GraphModule,
|
||||
example_inputs,
|
||||
additional_inductor_config,
|
||||
compilation_config: CompilationConfig,
|
||||
graph_index: int = 0,
|
||||
num_graphs: int = 1,
|
||||
runtime_shape: int | None = None,
|
||||
) -> Any:
|
||||
if graph_index == 0:
|
||||
# before compiling the first graph, record the start time
|
||||
global compilation_start_time
|
||||
compilation_start_time = time.time()
|
||||
|
||||
compilation_counter.num_backend_compilations += 1
|
||||
|
||||
compiled_graph = None
|
||||
|
||||
# try to load from the cache
|
||||
compiled_graph = self.load(graph, example_inputs, graph_index, runtime_shape)
|
||||
if compiled_graph is not None:
|
||||
if graph_index == num_graphs - 1:
|
||||
# after loading the last graph for this shape, record the time.
|
||||
# there can be multiple graphs due to piecewise compilation.
|
||||
now = time.time()
|
||||
elapsed = now - compilation_start_time
|
||||
compilation_config.compilation_time += elapsed
|
||||
if runtime_shape is None:
|
||||
logger.info(
|
||||
"Directly load the compiled graph(s) for dynamic shape "
|
||||
"from the cache, took %.3f s",
|
||||
elapsed,
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"Directly load the compiled graph(s) for shape %s "
|
||||
"from the cache, took %.3f s",
|
||||
str(runtime_shape),
|
||||
elapsed,
|
||||
)
|
||||
return compiled_graph
|
||||
|
||||
# no compiler cached the graph, or the cache is disabled,
|
||||
# we need to compile it
|
||||
if isinstance(self.compiler, InductorAdaptor):
|
||||
# Let compile_fx generate a key for us
|
||||
maybe_key = None
|
||||
else:
|
||||
maybe_key = f"artifact_shape_{runtime_shape}_subgraph_{graph_index}"
|
||||
|
||||
with self.compile_context(runtime_shape):
|
||||
compiled_graph, handle = self.compiler.compile(
|
||||
graph,
|
||||
example_inputs,
|
||||
additional_inductor_config,
|
||||
runtime_shape,
|
||||
maybe_key,
|
||||
)
|
||||
|
||||
assert compiled_graph is not None, "Failed to compile the graph"
|
||||
|
||||
# store the artifact in the cache
|
||||
if is_compile_cache_enabled(additional_inductor_config) and handle is not None:
|
||||
self.cache[(runtime_shape, graph_index, self.compiler.name)] = handle
|
||||
compilation_counter.num_cache_entries_updated += 1
|
||||
self.is_cache_updated = True
|
||||
if graph_index == 0:
|
||||
# adds some info logging for the first graph
|
||||
if runtime_shape is None:
|
||||
logger.info_once(
|
||||
"Cache the graph for dynamic shape for later use", scope="local"
|
||||
)
|
||||
else:
|
||||
logger.info_once(
|
||||
"Cache the graph of shape %s for later use",
|
||||
str(runtime_shape),
|
||||
scope="local",
|
||||
)
|
||||
if runtime_shape is None:
|
||||
logger.debug(
|
||||
"Store the %s-th graph for dynamic shape from %s via handle %s",
|
||||
graph_index,
|
||||
self.compiler.name,
|
||||
handle,
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
"Store the %s-th graph for shape %s from %s via handle %s",
|
||||
graph_index,
|
||||
str(runtime_shape),
|
||||
self.compiler.name,
|
||||
handle,
|
||||
)
|
||||
|
||||
# after compiling the last graph, record the end time
|
||||
if graph_index == num_graphs - 1:
|
||||
now = time.time()
|
||||
elapsed = now - compilation_start_time
|
||||
compilation_config.compilation_time += elapsed
|
||||
if runtime_shape is None:
|
||||
logger.info_once(
|
||||
"Compiling a graph for dynamic shape takes %.2f s",
|
||||
elapsed,
|
||||
scope="local",
|
||||
)
|
||||
else:
|
||||
logger.info_once(
|
||||
"Compiling a graph for shape %s takes %.2f s",
|
||||
runtime_shape,
|
||||
elapsed,
|
||||
scope="local",
|
||||
)
|
||||
|
||||
return compiled_graph
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class SplitItem:
|
||||
submod_name: str
|
||||
graph_id: int
|
||||
is_splitting_graph: bool
|
||||
graph: fx.GraphModule
|
||||
|
||||
|
||||
def split_graph(
|
||||
graph: fx.GraphModule, splitting_ops: list[str]
|
||||
) -> tuple[fx.GraphModule, list[SplitItem]]:
|
||||
# split graph by ops
|
||||
subgraph_id = 0
|
||||
node_to_subgraph_id: dict[fx.Node, int] = {}
|
||||
split_op_graphs: list[int] = []
|
||||
for node in graph.graph.nodes:
|
||||
if node.op in ("output", "placeholder"):
|
||||
continue
|
||||
|
||||
# Check if this is a getitem operation on a node from an earlier subgraph.
|
||||
# If so, assign it to the same subgraph as its input to avoid passing entire
|
||||
# tuple as input to submodules, which is against standalone_compile and
|
||||
# AoTAutograd input requirement.
|
||||
if node.op == "call_function" and node.target == operator.getitem:
|
||||
# Assign this getitem to the same subgraph as its input
|
||||
input_node = node.args[0]
|
||||
if input_node.op != "placeholder":
|
||||
assert input_node in node_to_subgraph_id
|
||||
node_to_subgraph_id[node] = node_to_subgraph_id[input_node]
|
||||
continue
|
||||
|
||||
if should_split(node, splitting_ops):
|
||||
subgraph_id += 1
|
||||
node_to_subgraph_id[node] = subgraph_id
|
||||
split_op_graphs.append(subgraph_id)
|
||||
subgraph_id += 1
|
||||
else:
|
||||
node_to_subgraph_id[node] = subgraph_id
|
||||
|
||||
# `keep_original_order` is important!
|
||||
# otherwise pytorch might reorder the nodes and
|
||||
# the semantics of the graph will change when we
|
||||
# have mutations in the graph
|
||||
split_gm = torch.fx.passes.split_module.split_module(
|
||||
graph, None, lambda node: node_to_subgraph_id[node], keep_original_order=True
|
||||
)
|
||||
|
||||
outputs = []
|
||||
|
||||
names = [name for (name, module) in split_gm.named_modules()]
|
||||
|
||||
for name in names:
|
||||
if "." in name or name == "":
|
||||
# recursive child module or the root module
|
||||
continue
|
||||
|
||||
module = getattr(split_gm, name)
|
||||
|
||||
graph_id = int(name.replace("submod_", ""))
|
||||
outputs.append(SplitItem(name, graph_id, (graph_id in split_op_graphs), module))
|
||||
|
||||
# sort by integer graph_id, rather than string name
|
||||
outputs.sort(key=lambda x: x.graph_id)
|
||||
|
||||
return split_gm, outputs
|
||||
|
||||
|
||||
compilation_start_time = 0.0
|
||||
|
||||
|
||||
class PiecewiseCompileInterpreter(torch.fx.Interpreter):
|
||||
"""Code adapted from `torch.fx.passes.shape_prop.ShapeProp`.
|
||||
It runs the given graph with fake inputs, and compile some
|
||||
submodules specified by `compile_submod_names` with the given
|
||||
compilation configs.
|
||||
|
||||
NOTE: the order in `compile_submod_names` matters, because
|
||||
it will be used to determine the order of the compiled piecewise
|
||||
graphs. The first graph will handle logging, and the last graph
|
||||
has some special cudagraph output handling.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
module: torch.fx.GraphModule,
|
||||
compile_submod_names: list[str],
|
||||
vllm_config: VllmConfig,
|
||||
vllm_backend: "VllmBackend",
|
||||
):
|
||||
super().__init__(module)
|
||||
from torch._guards import detect_fake_mode
|
||||
|
||||
self.fake_mode = detect_fake_mode()
|
||||
self.compile_submod_names = compile_submod_names
|
||||
self.compilation_config = vllm_config.compilation_config
|
||||
self.vllm_config = vllm_config
|
||||
self.vllm_backend = vllm_backend
|
||||
# When True, it annoyingly dumps the torch.fx.Graph on errors.
|
||||
self.extra_traceback = False
|
||||
|
||||
def run(self, *args):
|
||||
fake_args = [
|
||||
self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
|
||||
for t in args
|
||||
]
|
||||
with self.fake_mode, enable_python_dispatcher():
|
||||
return super().run(*fake_args)
|
||||
|
||||
def call_module(
|
||||
self,
|
||||
target: torch.fx.node.Target,
|
||||
args: tuple[torch.fx.node.Argument, ...],
|
||||
kwargs: dict[str, Any],
|
||||
) -> Any:
|
||||
assert isinstance(target, str)
|
||||
output = super().call_module(target, args, kwargs)
|
||||
|
||||
if target in self.compile_submod_names:
|
||||
index = self.compile_submod_names.index(target)
|
||||
submod = self.fetch_attr(target)
|
||||
sym_shape_indices = [
|
||||
i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
|
||||
]
|
||||
global compilation_start_time
|
||||
|
||||
compiled_graph_for_dynamic_shape = (
|
||||
self.vllm_backend.compiler_manager.compile(
|
||||
submod,
|
||||
args,
|
||||
self.compilation_config.inductor_compile_config,
|
||||
self.compilation_config,
|
||||
graph_index=index,
|
||||
num_graphs=len(self.compile_submod_names),
|
||||
runtime_shape=None,
|
||||
)
|
||||
)
|
||||
# Lazy import here to avoid circular import
|
||||
from .piecewise_backend import PiecewiseBackend
|
||||
|
||||
piecewise_backend = PiecewiseBackend(
|
||||
submod,
|
||||
self.vllm_config,
|
||||
index,
|
||||
len(self.compile_submod_names),
|
||||
sym_shape_indices,
|
||||
compiled_graph_for_dynamic_shape,
|
||||
self.vllm_backend,
|
||||
)
|
||||
|
||||
if (
|
||||
self.compilation_config.cudagraph_mode.has_piecewise_cudagraphs()
|
||||
and not self.compilation_config.use_inductor_graph_partition
|
||||
):
|
||||
# We're using Dynamo-based piecewise splitting, so we wrap
|
||||
# the whole subgraph with a static graph wrapper.
|
||||
from .cuda_graph import CUDAGraphOptions
|
||||
|
||||
# resolve the static graph wrapper class (e.g. CUDAGraphWrapper
|
||||
# class) as platform dependent.
|
||||
static_graph_wrapper_class = resolve_obj_by_qualname(
|
||||
current_platform.get_static_graph_wrapper_cls()
|
||||
)
|
||||
|
||||
# Always assign PIECEWISE runtime mode to the
|
||||
# CUDAGraphWrapper for piecewise_backend, to distinguish
|
||||
# it from the FULL cudagraph runtime mode, no matter it
|
||||
# is wrapped on a full or piecewise fx graph.
|
||||
self.module.__dict__[target] = static_graph_wrapper_class(
|
||||
runnable=submod.forward,
|
||||
vllm_config=self.vllm_config,
|
||||
runtime_mode=CUDAGraphMode.PIECEWISE,
|
||||
cudagraph_options=CUDAGraphOptions(
|
||||
debug_log_enable=piecewise_backend.is_first_graph,
|
||||
gc_disable=not piecewise_backend.is_first_graph,
|
||||
weak_ref_output=piecewise_backend.is_last_graph,
|
||||
),
|
||||
)
|
||||
else:
|
||||
self.module.__dict__[target] = piecewise_backend
|
||||
|
||||
compilation_counter.num_piecewise_capturable_graphs_seen += 1
|
||||
|
||||
return output
|
||||
|
||||
|
||||
# the tag for the part of model being compiled,
|
||||
# e.g. backbone/eagle_head
|
||||
model_tag: str = "backbone"
|
||||
|
||||
|
||||
@contextmanager
|
||||
def set_model_tag(tag: str):
|
||||
"""Context manager to set the model tag."""
|
||||
global model_tag
|
||||
assert tag != model_tag, (
|
||||
f"Model tag {tag} is the same as the current tag {model_tag}."
|
||||
)
|
||||
old_tag = model_tag
|
||||
model_tag = tag
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
model_tag = old_tag
|
||||
|
||||
|
||||
class VllmBackend:
|
||||
"""The compilation backend for `torch.compile` with vLLM.
|
||||
It is used for compilation mode of `CompilationMode.VLLM_COMPILE`,
|
||||
where we customize the compilation.
|
||||
|
||||
The major work of this backend is to split the graph into
|
||||
piecewise graphs, and pass them to the piecewise backend.
|
||||
|
||||
This backend also adds the PostGradPassManager to Inductor config,
|
||||
which handles the post-grad passes.
|
||||
"""
|
||||
|
||||
vllm_config: VllmConfig
|
||||
compilation_config: CompilationConfig
|
||||
_called: bool = False
|
||||
# the graph we compiled
|
||||
graph: fx.GraphModule
|
||||
# the stiching graph module for all the piecewise graphs
|
||||
split_gm: fx.GraphModule
|
||||
piecewise_graphs: list[SplitItem]
|
||||
returned_callable: Callable
|
||||
# Inductor passes to run on the graph pre-defunctionalization
|
||||
post_grad_passes: Sequence[Callable]
|
||||
sym_tensor_indices: list[int]
|
||||
input_buffers: list[torch.Tensor]
|
||||
compiler_manager: CompilerManager
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
):
|
||||
# if the model is initialized with a non-empty prefix,
|
||||
# then usually it's enough to use that prefix,
|
||||
# e.g. language_model, vision_model, etc.
|
||||
# when multiple parts are initialized as independent
|
||||
# models, we need to use the model_tag to distinguish
|
||||
# them, e.g. backbone (default), eagle_head, etc.
|
||||
self.prefix = prefix or model_tag
|
||||
|
||||
# Passes to run on the graph post-grad.
|
||||
# self.post_grad_pass_manager = PostGradPassManager()
|
||||
|
||||
self.sym_tensor_indices = []
|
||||
self.input_buffers = []
|
||||
|
||||
self.vllm_config = vllm_config
|
||||
self.compilation_config = vllm_config.compilation_config
|
||||
|
||||
self.compiler_manager: CompilerManager = CompilerManager(
|
||||
self.compilation_config
|
||||
)
|
||||
|
||||
# `torch.compile` is JIT compiled, so we don't need to
|
||||
# do anything here
|
||||
|
||||
def configure_post_pass(self):
|
||||
config = self.compilation_config
|
||||
self.post_grad_pass_manager.configure(self.vllm_config)
|
||||
|
||||
# Post-grad custom passes are run using the post_grad_custom_post_pass
|
||||
# hook. If a pass for that hook exists, add it to the pass manager.
|
||||
inductor_config = config.inductor_compile_config
|
||||
PASS_KEY = "post_grad_custom_post_pass"
|
||||
if PASS_KEY in inductor_config:
|
||||
if isinstance(inductor_config[PASS_KEY], PostGradPassManager):
|
||||
# PassManager already added to config, make sure it's correct
|
||||
assert (
|
||||
inductor_config[PASS_KEY].uuid()
|
||||
== self.post_grad_pass_manager.uuid()
|
||||
)
|
||||
else:
|
||||
# Config should automatically wrap all inductor passes
|
||||
assert isinstance(inductor_config[PASS_KEY], InductorPass)
|
||||
self.post_grad_pass_manager.add(inductor_config[PASS_KEY])
|
||||
inductor_config[PASS_KEY] = self.post_grad_pass_manager
|
||||
|
||||
def __call__(
|
||||
self, graph: fx.GraphModule, example_inputs,**kwargs
|
||||
) -> VllmSerializableFunction:
|
||||
from .caching import _compute_code_hash, compilation_config_hash_factors
|
||||
|
||||
vllm_config = self.vllm_config
|
||||
if not self.compilation_config.cache_dir:
|
||||
# no provided cache dir, generate one based on the known factors
|
||||
# that affects the compilation. if none of the factors change,
|
||||
# the cache dir will be the same so that we can reuse the compiled
|
||||
# graph.
|
||||
|
||||
factors = compilation_config_hash_factors(vllm_config)
|
||||
# 2. factors come from the code files that are traced by Dynamo (
|
||||
# it mainly summarizes how the model is used in forward pass)
|
||||
code_hash = _compute_code_hash(self.compilation_config.traced_files)
|
||||
self.compilation_config.traced_files.clear()
|
||||
factors.append(code_hash)
|
||||
|
||||
# 3. compiler hash
|
||||
compiler_hash = self.compiler_manager.compute_hash(vllm_config)
|
||||
factors.append(compiler_hash)
|
||||
|
||||
# combine all factors to generate the cache dir
|
||||
hash_key = hashlib.md5(
|
||||
str(factors).encode(), usedforsecurity=False
|
||||
).hexdigest()[:10]
|
||||
|
||||
cache_dir = os.path.join(
|
||||
envs.VLLM_CACHE_ROOT,
|
||||
"torch_compile_cache",
|
||||
hash_key,
|
||||
)
|
||||
self.compilation_config.cache_dir = cache_dir
|
||||
|
||||
cache_dir = self.compilation_config.cache_dir
|
||||
os.makedirs(cache_dir, exist_ok=True)
|
||||
self.compilation_config.cache_dir = cache_dir
|
||||
rank = vllm_config.parallel_config.rank
|
||||
dp_rank = vllm_config.parallel_config.data_parallel_rank
|
||||
local_cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}", self.prefix)
|
||||
os.makedirs(local_cache_dir, exist_ok=True)
|
||||
self.compilation_config.local_cache_dir = local_cache_dir
|
||||
|
||||
disable_cache = not is_compile_cache_enabled(
|
||||
self.compilation_config.inductor_compile_config
|
||||
)
|
||||
|
||||
if disable_cache:
|
||||
logger.info_once("vLLM's torch.compile cache is disabled.", scope="local")
|
||||
else:
|
||||
logger.info_once(
|
||||
"Using cache directory: %s for vLLM's torch.compile",
|
||||
local_cache_dir,
|
||||
scope="local",
|
||||
)
|
||||
|
||||
self.compiler_manager.initialize_cache(
|
||||
local_cache_dir, disable_cache, self.prefix
|
||||
)
|
||||
|
||||
# when dynamo calls the backend, it means the bytecode
|
||||
# transform and analysis are done
|
||||
compilation_counter.num_graphs_seen += 1
|
||||
from .monitor import torch_compile_start_time
|
||||
|
||||
dynamo_time = time.time() - torch_compile_start_time
|
||||
logger.info_once(
|
||||
"Dynamo bytecode transform time: %.2f s", dynamo_time, scope="local"
|
||||
)
|
||||
self.compilation_config.compilation_time += dynamo_time
|
||||
|
||||
# we control the compilation process, each instance can only be
|
||||
# called once
|
||||
assert not self._called, "VllmBackend can only be called once"
|
||||
|
||||
self.graph = graph
|
||||
# self.configure_post_pass()
|
||||
|
||||
if self.compilation_config.use_inductor_graph_partition:
|
||||
# Let Inductor decide partitioning; avoid FX-level pre-splitting.
|
||||
fx_split_ops: list[str] = []
|
||||
else:
|
||||
fx_split_ops = self.compilation_config.splitting_ops or []
|
||||
|
||||
self.split_gm, self.piecewise_graphs = split_graph(graph, fx_split_ops)
|
||||
|
||||
from torch._dynamo.utils import lazy_format_graph_code
|
||||
|
||||
# depyf will hook lazy_format_graph_code and dump the graph
|
||||
# for debugging, no need to print the graph here
|
||||
lazy_format_graph_code("before split", self.graph)
|
||||
lazy_format_graph_code("after split", self.split_gm)
|
||||
|
||||
compilation_counter.num_piecewise_graphs_seen += len(self.piecewise_graphs)
|
||||
submod_names_to_compile = [
|
||||
item.submod_name
|
||||
for item in self.piecewise_graphs
|
||||
if not item.is_splitting_graph
|
||||
]
|
||||
|
||||
# propagate the split graph to the piecewise backend,
|
||||
# compile submodules with symbolic shapes
|
||||
PiecewiseCompileInterpreter(
|
||||
self.split_gm, submod_names_to_compile, self.vllm_config, self
|
||||
).run(*example_inputs)
|
||||
|
||||
graph_path = os.path.join(local_cache_dir, "computation_graph.py")
|
||||
if not os.path.exists(graph_path):
|
||||
# code adapted from
|
||||
# https://github.com/thuml/depyf/blob/dab831108a752d1facc00acdd6d4243891845c37/depyf/explain/patched_lazy_format_graph_code.py#L30
|
||||
# use `print_readable` because it can include submodules
|
||||
src = (
|
||||
"from __future__ import annotations\nimport torch\n"
|
||||
+ self.split_gm.print_readable(print_output=False)
|
||||
)
|
||||
src = src.replace("<lambda>", "GraphModule")
|
||||
with open(graph_path, "w") as f:
|
||||
f.write(src)
|
||||
|
||||
logger.debug_once(
|
||||
"Computation graph saved to %s", graph_path, scope="local"
|
||||
)
|
||||
|
||||
self._called = True
|
||||
|
||||
if (
|
||||
self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
|
||||
or not self.compilation_config.cudagraph_copy_inputs
|
||||
):
|
||||
return VllmSerializableFunction(
|
||||
graph, example_inputs, self.prefix, self.split_gm
|
||||
)
|
||||
|
||||
# if we need to copy input buffers for cudagraph
|
||||
from torch._guards import detect_fake_mode
|
||||
|
||||
fake_mode = detect_fake_mode()
|
||||
fake_args = [
|
||||
fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
|
||||
for t in example_inputs
|
||||
]
|
||||
|
||||
# index of tensors that have symbolic shapes (batch size)
|
||||
# for weights and static buffers, they will have concrete shapes.
|
||||
# symbolic shape only happens for input tensors.
|
||||
from torch.fx.experimental.symbolic_shapes import is_symbolic
|
||||
|
||||
self.sym_tensor_indices = [
|
||||
i
|
||||
for i, x in enumerate(fake_args)
|
||||
if isinstance(x, torch._subclasses.fake_tensor.FakeTensor)
|
||||
and any(is_symbolic(d) for d in x.size())
|
||||
]
|
||||
|
||||
# compiler managed cudagraph input buffers
|
||||
# we assume the first run with symbolic shapes
|
||||
# has the maximum size among all the tensors
|
||||
self.input_buffers = [
|
||||
example_inputs[x].clone() for x in self.sym_tensor_indices
|
||||
]
|
||||
|
||||
# this is the callable we return to Dynamo to run
|
||||
def copy_and_call(*args):
|
||||
list_args = list(args)
|
||||
for i, index in enumerate(self.sym_tensor_indices):
|
||||
runtime_tensor = list_args[index]
|
||||
runtime_shape = runtime_tensor.shape[0]
|
||||
static_tensor = self.input_buffers[i][:runtime_shape]
|
||||
|
||||
# copy the tensor to the static buffer
|
||||
static_tensor.copy_(runtime_tensor)
|
||||
|
||||
# replace the tensor in the list_args to the static buffer
|
||||
list_args[index] = static_tensor
|
||||
return self.split_gm(*list_args)
|
||||
|
||||
return VllmSerializableFunction(
|
||||
graph, example_inputs, self.prefix, copy_and_call
|
||||
)
|
||||
57
compilation/base_static_graph.py
Normal file
57
compilation/base_static_graph.py
Normal file
@@ -0,0 +1,57 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Protocol
|
||||
|
||||
from vllm.config import CUDAGraphMode, VllmConfig
|
||||
|
||||
|
||||
class AbstractStaticGraphWrapper(Protocol):
|
||||
"""
|
||||
StaticGraphWrapper interface that allows platforms to wrap a callable
|
||||
to be captured as a static graph.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
runnable: Callable[..., Any],
|
||||
vllm_config: VllmConfig,
|
||||
runtime_mode: CUDAGraphMode,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Initializes the StaticGraphWrapper class with graph capturing and
|
||||
execution-related configurations.
|
||||
|
||||
Args:
|
||||
runnable (Callable): The callable to be wrapped and captured.
|
||||
vllm_config (VllmConfig): Global configuration for vLLM.
|
||||
runtime_mode (CUDAGraphMode): The style of the static
|
||||
graph runtime. See CUDAGraphMode in vllm/config.py.
|
||||
Note that only the subset enum `NONE`, `PIECEWISE` and `FULL`
|
||||
are used as concrete runtime mode for cudagraph dispatching.
|
||||
Keyword Args:
|
||||
kwargs: Additional keyword arguments for platform-specific
|
||||
configurations.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
||||
"""
|
||||
Executes the wrapped callable.
|
||||
|
||||
If the current runtime mode in the ForwardContext matches the runtime
|
||||
mode of this instance, it replays the CUDAGraph or captures it using
|
||||
the callable if it hasn't been captured yet. Otherwise, it calls the
|
||||
original callable directly.
|
||||
|
||||
Args:
|
||||
*args: Variable length input arguments to be passed into the
|
||||
callable.
|
||||
**kwargs: Keyword arguments to be passed into the callable.
|
||||
|
||||
Returns:
|
||||
Any: Output of the executed callable.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
178
compilation/caching.py
Normal file
178
compilation/caching.py
Normal file
@@ -0,0 +1,178 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import hashlib
|
||||
import inspect
|
||||
import os
|
||||
import pickle
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
from torch.utils import _pytree as pytree
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import VllmConfig, get_current_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
|
||||
try:
|
||||
from torch._dynamo.aot_compile import SerializableCallable
|
||||
except ImportError:
|
||||
SerializableCallable = object
|
||||
|
||||
assert isinstance(SerializableCallable, type)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class VllmSerializableFunction(SerializableCallable):
|
||||
"""
|
||||
A wrapper around a compiled function by vllm. It will forward the tensor
|
||||
inputs to the compiled function and return the result.
|
||||
It also implements a serialization interface to support PyTorch's precompile
|
||||
with custom backend, so that we can save and load the compiled function on
|
||||
disk. There's no need to wrap around the compiled function if we don't want
|
||||
to serialize them in particular cases.
|
||||
Right now serialization for the custom backend is done via
|
||||
serializing the Dynamo fx graph plus example inputs.
|
||||
"""
|
||||
|
||||
def __init__(self, graph_module, example_inputs, prefix, optimized_call):
|
||||
assert isinstance(graph_module, torch.fx.GraphModule)
|
||||
self.graph_module = graph_module
|
||||
self.example_inputs = example_inputs
|
||||
self.prefix = prefix
|
||||
self.optimized_call = optimized_call
|
||||
self.shape_env = None
|
||||
sym_input = next(
|
||||
(i for i in self.example_inputs if isinstance(i, torch.SymInt)), None
|
||||
)
|
||||
if sym_input is not None:
|
||||
self.shape_env = sym_input.node.shape_env
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.optimized_call(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def serialize_compile_artifacts(
|
||||
cls, compiled_fn: "VllmSerializableFunction"
|
||||
) -> bytes:
|
||||
import sympy
|
||||
from torch._subclasses import FakeTensorMode
|
||||
from torch.fx._graph_pickler import GraphPickler, Options
|
||||
|
||||
state = compiled_fn.__dict__.copy()
|
||||
state.pop("optimized_call")
|
||||
state.pop("shape_env")
|
||||
for node in state["graph_module"].graph.nodes:
|
||||
node.meta.pop("source_fn_stack", None)
|
||||
node.meta.pop("nn_module_stack", None)
|
||||
|
||||
graph_reducer_override = GraphPickler.reducer_override
|
||||
|
||||
def _graph_reducer_override(self, obj):
|
||||
if (
|
||||
inspect.isclass(obj)
|
||||
and issubclass(obj, sympy.Function)
|
||||
and hasattr(obj, "_torch_unpickler")
|
||||
):
|
||||
return obj._torch_unpickler, (obj._torch_handler_name,)
|
||||
if isinstance(obj, FakeTensorMode):
|
||||
return type(None), ()
|
||||
return graph_reducer_override(self, obj)
|
||||
|
||||
# Mask off tensor inputs since they are large and not needed.
|
||||
state["example_inputs"] = pytree.tree_map_only(
|
||||
torch.Tensor, lambda _: None, state["example_inputs"]
|
||||
)
|
||||
with patch.object(GraphPickler, "reducer_override", _graph_reducer_override):
|
||||
state["graph_module"] = GraphPickler.dumps(
|
||||
state["graph_module"], Options(ops_filter=None)
|
||||
)
|
||||
state["example_inputs"] = GraphPickler.dumps(state["example_inputs"])
|
||||
return pickle.dumps(state)
|
||||
|
||||
@classmethod
|
||||
def deserialize_compile_artifacts(cls, data: bytes) -> "VllmSerializableFunction":
|
||||
from torch._guards import TracingContext, tracing
|
||||
from torch._subclasses import FakeTensorMode
|
||||
from torch.fx._graph_pickler import GraphPickler
|
||||
from torch.fx.experimental.symbolic_shapes import ShapeEnv
|
||||
|
||||
from vllm.compilation.backends import VllmBackend
|
||||
|
||||
state = pickle.loads(data)
|
||||
fake_mode = FakeTensorMode(shape_env=ShapeEnv())
|
||||
state["graph_module"] = GraphPickler.loads(state["graph_module"], fake_mode)
|
||||
state["example_inputs"] = GraphPickler.loads(state["example_inputs"], fake_mode)
|
||||
vllm_backend = VllmBackend(get_current_vllm_config(), state["prefix"])
|
||||
|
||||
def optimized_call(*example_inputs):
|
||||
"""
|
||||
On the first run of the optimized call, we rerun the compiler
|
||||
backend which should result in a cache hit. After the backend
|
||||
call returns, we just do a one-time replacement of the optimized
|
||||
call with the compiled function, so that subsequent calls are on
|
||||
the AOT compiled path.
|
||||
"""
|
||||
compile_inputs = [
|
||||
inp or example_inputs[i] for i, inp in enumerate(fn.example_inputs)
|
||||
]
|
||||
with tracing(TracingContext(fake_mode)):
|
||||
fn.optimized_call = vllm_backend(
|
||||
state["graph_module"], compile_inputs
|
||||
).optimized_call
|
||||
return fn.optimized_call(*example_inputs)
|
||||
|
||||
fn = cls(**state, optimized_call=optimized_call)
|
||||
return fn
|
||||
|
||||
@property
|
||||
def co_name(self):
|
||||
"""
|
||||
Used for depyf debugging.
|
||||
"""
|
||||
return "VllmSerializableFunction"
|
||||
|
||||
|
||||
def compilation_config_hash_factors(vllm_config: VllmConfig) -> list[str]:
|
||||
factors = []
|
||||
# 0. factors come from the env, for example, The values of
|
||||
# VLLM_PP_LAYER_PARTITION will affect the computation graph.
|
||||
env_hash = envs.compute_hash()
|
||||
factors.append(env_hash)
|
||||
|
||||
# 1. factors come from the vllm_config (it mainly summarizes how the
|
||||
# model is created)
|
||||
config_hash = vllm_config.compute_hash()
|
||||
factors.append(config_hash)
|
||||
return factors
|
||||
|
||||
|
||||
def _compute_code_hash_with_content(file_contents: dict[str, str]) -> str:
|
||||
items = list(sorted(file_contents.items(), key=lambda x: x[0]))
|
||||
hash_content = []
|
||||
for filepath, content in items:
|
||||
hash_content.append(filepath)
|
||||
if filepath == "<string>":
|
||||
# This means the function was dynamically generated, with
|
||||
# e.g. exec(). We can't actually check these.
|
||||
continue
|
||||
hash_content.append(content)
|
||||
return hashlib.md5(
|
||||
"\n".join(hash_content).encode(), usedforsecurity=False
|
||||
).hexdigest()
|
||||
|
||||
|
||||
def _compute_code_hash(files: set[str]) -> str:
|
||||
logger.debug(
|
||||
"Traced files (to be considered for compilation cache):\n%s", "\n".join(files)
|
||||
)
|
||||
file_contents = {}
|
||||
for filepath in files:
|
||||
# Skip files that don't exist (e.g., <string>, <frozen modules>, etc.)
|
||||
if not os.path.isfile(filepath):
|
||||
file_contents[filepath] = ""
|
||||
else:
|
||||
with open(filepath) as f:
|
||||
file_contents[filepath] = f.read()
|
||||
return _compute_code_hash_with_content(file_contents)
|
||||
1234
compilation/collective_fusion.py
Normal file
1234
compilation/collective_fusion.py
Normal file
File diff suppressed because it is too large
Load Diff
639
compilation/compiler_interface.py
Normal file
639
compilation/compiler_interface.py
Normal file
@@ -0,0 +1,639 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import contextlib
|
||||
import copy
|
||||
import hashlib
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from contextlib import ExitStack
|
||||
from typing import Any, Literal
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
import torch._inductor.compile_fx
|
||||
import torch.fx as fx
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||
|
||||
|
||||
class CompilerInterface:
|
||||
"""
|
||||
The interface for a compiler that can be used by vLLM.
|
||||
"""
|
||||
|
||||
# The name of the compiler, e.g. inductor.
|
||||
# This is a class-level attribute.
|
||||
name: str
|
||||
|
||||
def initialize_cache(
|
||||
self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
|
||||
):
|
||||
"""
|
||||
when the vLLM process uses `cache_dir` as the cache directory,
|
||||
the compiler should initialize itself with the cache directory,
|
||||
e.g. by re-directing its own cache directory to a sub-directory.
|
||||
|
||||
prefix can be used in combination with cache_dir to figure out the base
|
||||
cache directory, e.g. there're multiple parts of model being compiled,
|
||||
but we want to share the same cache directory for all of them.
|
||||
|
||||
e.g.
|
||||
cache_dir = "/path/to/dir/backbone", prefix = "backbone"
|
||||
cache_dir = "/path/to/dir/eagle_head", prefix = "eagle_head"
|
||||
"""
|
||||
pass
|
||||
|
||||
def compute_hash(self, vllm_config: VllmConfig) -> str:
|
||||
"""
|
||||
Gather all the relevant information from the vLLM config,
|
||||
to compute a hash so that we can cache the compiled model.
|
||||
|
||||
See [`VllmConfig.compute_hash`][vllm.config.VllmConfig.compute_hash]
|
||||
to check what information
|
||||
is already considered by default. This function should only
|
||||
consider the information that is specific to the compiler.
|
||||
"""
|
||||
return ""
|
||||
|
||||
def compile(
|
||||
self,
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: list[Any],
|
||||
compiler_config: dict[str, Any],
|
||||
runtime_shape: int | None = None,
|
||||
key: str | None = None,
|
||||
) -> tuple[Callable | None, Any | None]:
|
||||
"""
|
||||
Compile the graph with the given example inputs and compiler config,
|
||||
with a runtime shape. If the `runtime_shape` is None, it means
|
||||
the `example_inputs` have a dynamic shape. Otherwise, the
|
||||
`runtime_shape` specifies the shape of the inputs. Right now we only
|
||||
support one variable shape for all inputs, which is the batchsize
|
||||
(number of tokens) during inference.
|
||||
|
||||
Dynamo will make sure `graph(*example_inputs)` is valid.
|
||||
|
||||
The function should return a compiled callable function, as well as
|
||||
a handle that can be used to directly load the compiled function.
|
||||
|
||||
The handle should be a plain Python object, preferably a string or a
|
||||
file path for readability.
|
||||
|
||||
If the compiler doesn't support caching, it should return None for the
|
||||
handle. If the compiler fails to compile the graph, it should return
|
||||
None for the compiled function as well.
|
||||
|
||||
`key` is required for StandaloneInductorAdapter, it specifies where to
|
||||
save the compiled artifact. The compiled artifact gets saved to
|
||||
`cache_dir/key`.
|
||||
"""
|
||||
return None, None
|
||||
|
||||
def load(
|
||||
self,
|
||||
handle: Any,
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: list[Any],
|
||||
graph_index: int,
|
||||
runtime_shape: int | None = None,
|
||||
) -> Callable:
|
||||
"""
|
||||
Load the compiled function from the handle.
|
||||
Raises an error if the handle is invalid.
|
||||
|
||||
The handle is the second return value of the `compile` function.
|
||||
"""
|
||||
raise NotImplementedError("caching is not supported")
|
||||
|
||||
|
||||
class AlwaysHitShapeEnv:
|
||||
"""
|
||||
Why do we need this class:
|
||||
|
||||
For normal `torch.compile` usage, every compilation will have
|
||||
one Dynamo bytecode compilation and one Inductor compilation.
|
||||
The Inductor compilation happens under the context of the
|
||||
Dynamo bytecode compilation, and that context is used to
|
||||
determine the dynamic shape information, etc.
|
||||
|
||||
For our use case, we only run Dynamo bytecode compilation once,
|
||||
and run Inductor compilation multiple times with different shapes
|
||||
plus a general shape. The compilation for specific shapes happens
|
||||
outside of the context of the Dynamo bytecode compilation. At that
|
||||
time, we don't have shape environment to provide to Inductor, and
|
||||
it will fail the Inductor code cache lookup.
|
||||
|
||||
By providing a dummy shape environment that always hits, we can
|
||||
make the Inductor code cache lookup always hit, and we can
|
||||
compile the graph for different shapes as needed.
|
||||
|
||||
The following dummy methods are obtained by trial-and-error
|
||||
until it works.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.guards: list[Any] = []
|
||||
|
||||
def evaluate_guards_expression(self, *args, **kwargs):
|
||||
return True
|
||||
|
||||
def get_pruned_guards(self, *args, **kwargs):
|
||||
return []
|
||||
|
||||
def produce_guards_expression(self, *args, **kwargs):
|
||||
return ""
|
||||
|
||||
|
||||
def get_inductor_factors() -> list[Any]:
|
||||
factors: list[Any] = []
|
||||
# summarize system state
|
||||
from torch._inductor.codecache import CacheBase
|
||||
|
||||
system_factors = CacheBase.get_system()
|
||||
factors.append(system_factors)
|
||||
|
||||
# summarize pytorch state
|
||||
from torch._inductor.codecache import torch_key
|
||||
|
||||
torch_factors = torch_key()
|
||||
factors.append(torch_factors)
|
||||
return factors
|
||||
|
||||
|
||||
def is_compile_cache_enabled(
|
||||
vllm_additional_inductor_config: dict[str, Any],
|
||||
) -> bool:
|
||||
vllm_inductor_config_disable_cache = vllm_additional_inductor_config.get(
|
||||
"force_disable_caches", False
|
||||
)
|
||||
|
||||
# TODO(gmagogsfm): Replace torch._inductor.config.force_disable_caches
|
||||
# with torch.compiler.config.force_disable_caches when minimum PyTorch
|
||||
# version reaches 2.10
|
||||
return (
|
||||
not envs.VLLM_DISABLE_COMPILE_CACHE
|
||||
and not torch._inductor.config.force_disable_caches
|
||||
and not vllm_inductor_config_disable_cache
|
||||
)
|
||||
|
||||
|
||||
class InductorStandaloneAdaptor(CompilerInterface):
|
||||
"""
|
||||
The adaptor for the Inductor compiler.
|
||||
Requires PyTorch 2.8+.
|
||||
This is not on by default yet, but we plan to turn it on by default for
|
||||
PyTorch 2.8.
|
||||
|
||||
Use VLLM_USE_STANDALONE_COMPILE to toggle this on or off.
|
||||
"""
|
||||
|
||||
name = "inductor_standalone"
|
||||
|
||||
def __init__(self, save_format: Literal["binary", "unpacked"]):
|
||||
self.save_format = save_format
|
||||
|
||||
def compute_hash(self, vllm_config: VllmConfig) -> str:
|
||||
factors = get_inductor_factors()
|
||||
hash_str = hashlib.md5(
|
||||
str(factors).encode(), usedforsecurity=False
|
||||
).hexdigest()[:10]
|
||||
return hash_str
|
||||
|
||||
def initialize_cache(
|
||||
self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
|
||||
):
|
||||
self.cache_dir = cache_dir
|
||||
|
||||
def compile(
|
||||
self,
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: list[Any],
|
||||
compiler_config: dict[str, Any],
|
||||
runtime_shape: int | None = None,
|
||||
key: str | None = None,
|
||||
) -> tuple[Callable | None, Any | None]:
|
||||
compilation_counter.num_inductor_compiles += 1
|
||||
current_config = {}
|
||||
if compiler_config is not None:
|
||||
current_config.update(compiler_config)
|
||||
set_inductor_config(current_config, runtime_shape)
|
||||
# set_functorch_config()
|
||||
|
||||
if isinstance(runtime_shape, int):
|
||||
dynamic_shapes = "from_example_inputs"
|
||||
else:
|
||||
dynamic_shapes = "from_tracing_context"
|
||||
|
||||
from torch._inductor import standalone_compile
|
||||
|
||||
compiled_graph = standalone_compile(
|
||||
graph,
|
||||
example_inputs,
|
||||
dynamic_shapes=dynamic_shapes,
|
||||
options={"config_patches": current_config},
|
||||
)
|
||||
|
||||
# Save the compiled artifact to disk in the specified path
|
||||
assert key is not None
|
||||
path = os.path.join(self.cache_dir, key)
|
||||
|
||||
if is_compile_cache_enabled(compiler_config):
|
||||
compiled_graph.save(path=path, format=self.save_format)
|
||||
compilation_counter.num_compiled_artifacts_saved += 1
|
||||
return compiled_graph, (key, path)
|
||||
|
||||
def load(
|
||||
self,
|
||||
handle: Any,
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: list[Any],
|
||||
graph_index: int,
|
||||
runtime_shape: int | None = None,
|
||||
) -> Callable:
|
||||
assert isinstance(handle, tuple)
|
||||
assert isinstance(handle[0], str)
|
||||
assert isinstance(handle[1], str)
|
||||
path = handle[1]
|
||||
inductor_compiled_graph = torch._inductor.CompiledArtifact.load(
|
||||
path=path, format=self.save_format
|
||||
)
|
||||
from torch._inductor.compile_fx import graph_returns_tuple
|
||||
|
||||
returns_tuple = graph_returns_tuple(graph)
|
||||
|
||||
def compiled_graph_wrapper(*args):
|
||||
graph_output = inductor_compiled_graph(*args)
|
||||
# unpack the tuple if needed
|
||||
# TODO(rzou): the implication is that we're not
|
||||
# reading the python bytecode correctly in vLLM?
|
||||
if returns_tuple:
|
||||
return graph_output
|
||||
else:
|
||||
return graph_output[0]
|
||||
|
||||
return compiled_graph_wrapper
|
||||
|
||||
|
||||
class InductorAdaptor(CompilerInterface):
|
||||
"""
|
||||
The adaptor for the Inductor compiler, version 2.5, 2.6, 2.7.
|
||||
"""
|
||||
|
||||
name = "inductor"
|
||||
|
||||
def compute_hash(self, vllm_config: VllmConfig) -> str:
|
||||
factors = get_inductor_factors()
|
||||
hash_str = hashlib.md5(
|
||||
str(factors).encode(), usedforsecurity=False
|
||||
).hexdigest()[:10]
|
||||
return hash_str
|
||||
|
||||
def initialize_cache(
|
||||
self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
|
||||
):
|
||||
self.cache_dir = cache_dir
|
||||
self.prefix = prefix
|
||||
self.base_cache_dir = cache_dir[: -len(prefix)] if prefix else cache_dir
|
||||
if disable_cache:
|
||||
return
|
||||
# redirect the cache directory to a subdirectory
|
||||
# set flags so that Inductor and Triton store their cache
|
||||
# in the cache_dir, then users only need to copy the cache_dir
|
||||
# to another machine to reuse the cache.
|
||||
inductor_cache = os.path.join(self.base_cache_dir, "inductor_cache")
|
||||
os.makedirs(inductor_cache, exist_ok=True)
|
||||
os.environ["TORCHINDUCTOR_CACHE_DIR"] = inductor_cache
|
||||
triton_cache = os.path.join(self.base_cache_dir, "triton_cache")
|
||||
os.makedirs(triton_cache, exist_ok=True)
|
||||
os.environ["TRITON_CACHE_DIR"] = triton_cache
|
||||
|
||||
def compile(
|
||||
self,
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: list[Any],
|
||||
compiler_config: dict[str, Any],
|
||||
runtime_shape: int | None = None,
|
||||
key: str | None = None,
|
||||
) -> tuple[Callable | None, Any | None]:
|
||||
compilation_counter.num_inductor_compiles += 1
|
||||
from torch._inductor.compile_fx import compile_fx
|
||||
|
||||
current_config = {}
|
||||
if compiler_config is not None:
|
||||
current_config.update(compiler_config)
|
||||
|
||||
# disable remote cache
|
||||
current_config["fx_graph_cache"] = True
|
||||
current_config["fx_graph_remote_cache"] = False
|
||||
|
||||
set_inductor_config(current_config, runtime_shape)
|
||||
# set_functorch_config()
|
||||
|
||||
# inductor can inplace modify the graph, so we need to copy it
|
||||
# see https://github.com/pytorch/pytorch/issues/138980
|
||||
graph = copy.deepcopy(graph)
|
||||
|
||||
# it's the first time we compile this graph
|
||||
# the assumption is that we don't have nested Inductor compilation.
|
||||
# compiled_fx_graph_hash will only be called once, and we can hook
|
||||
# it to get the hash of the compiled graph directly.
|
||||
|
||||
hash_str, file_path = None, None
|
||||
from torch._inductor.codecache import FxGraphCache, compiled_fx_graph_hash
|
||||
|
||||
if torch.__version__.startswith("2.5"):
|
||||
original_load = FxGraphCache.load
|
||||
original_load_name = "torch._inductor.codecache.FxGraphCache.load"
|
||||
|
||||
def hijack_load(*args, **kwargs):
|
||||
inductor_compiled_graph = original_load(*args, **kwargs)
|
||||
nonlocal file_path
|
||||
compiled_fn = inductor_compiled_graph.current_callable
|
||||
file_path = compiled_fn.__code__.co_filename # noqa
|
||||
if (
|
||||
not file_path.startswith(self.base_cache_dir)
|
||||
and compiled_fn.__closure__ is not None
|
||||
):
|
||||
# hooked in the align_inputs_from_check_idxs function
|
||||
# in torch/_inductor/utils.py
|
||||
for cell in compiled_fn.__closure__:
|
||||
if not callable(cell.cell_contents):
|
||||
continue
|
||||
if cell.cell_contents.__code__.co_filename.startswith(
|
||||
self.base_cache_dir
|
||||
):
|
||||
# this is the real file path compiled from Inductor
|
||||
file_path = cell.cell_contents.__code__.co_filename
|
||||
break
|
||||
return inductor_compiled_graph
|
||||
|
||||
hijacked_compile_fx_inner = torch._inductor.compile_fx.compile_fx_inner # noqa
|
||||
elif torch.__version__ >= "2.6":
|
||||
# function renamed in 2.6
|
||||
original_load_name = None
|
||||
|
||||
def hijacked_compile_fx_inner(*args, **kwargs):
|
||||
output = torch._inductor.compile_fx.compile_fx_inner(*args, **kwargs)
|
||||
nonlocal hash_str
|
||||
inductor_compiled_graph = output
|
||||
if inductor_compiled_graph is not None:
|
||||
nonlocal file_path
|
||||
compiled_fn = inductor_compiled_graph.current_callable
|
||||
file_path = compiled_fn.__code__.co_filename # noqa
|
||||
if (
|
||||
not file_path.startswith(self.base_cache_dir)
|
||||
and compiled_fn.__closure__ is not None
|
||||
):
|
||||
# hooked in the align_inputs_from_check_idxs function
|
||||
# in torch/_inductor/utils.py
|
||||
for cell in compiled_fn.__closure__:
|
||||
if not callable(cell.cell_contents):
|
||||
continue
|
||||
code = cell.cell_contents.__code__
|
||||
if code.co_filename.startswith(self.base_cache_dir):
|
||||
# this is the real file path
|
||||
# compiled from Inductor
|
||||
file_path = code.co_filename
|
||||
break
|
||||
hash_str = inductor_compiled_graph._fx_graph_cache_key
|
||||
return output
|
||||
|
||||
def hijack_compiled_fx_graph_hash(*args, **kwargs):
|
||||
out = compiled_fx_graph_hash(*args, **kwargs)
|
||||
nonlocal hash_str
|
||||
hash_str = out[0]
|
||||
return out
|
||||
|
||||
def _check_can_cache(*args, **kwargs):
|
||||
# no error means it can be cached.
|
||||
# Inductor refuses to cache the graph outside of Dynamo
|
||||
# tracing context, and also disables caching for graphs
|
||||
# with high-order ops.
|
||||
# For vLLM, in either case, we want to cache the graph.
|
||||
# see https://github.com/pytorch/pytorch/blob/9f5ebf3fc609105a74eab4ccc24932d6353ff566/torch/_inductor/codecache.py#L1221 # noqa
|
||||
return
|
||||
|
||||
def _get_shape_env() -> AlwaysHitShapeEnv:
|
||||
return AlwaysHitShapeEnv()
|
||||
|
||||
with ExitStack() as stack:
|
||||
# hijack to get the compiled graph itself
|
||||
if original_load_name is not None:
|
||||
stack.enter_context(patch(original_load_name, hijack_load))
|
||||
|
||||
# for hijacking the hash of the compiled graph
|
||||
stack.enter_context(
|
||||
patch(
|
||||
"torch._inductor.codecache.compiled_fx_graph_hash",
|
||||
hijack_compiled_fx_graph_hash,
|
||||
)
|
||||
)
|
||||
|
||||
# for providing a dummy shape environment
|
||||
stack.enter_context(
|
||||
patch(
|
||||
"torch._inductor.codecache.FxGraphCache._get_shape_env",
|
||||
_get_shape_env,
|
||||
)
|
||||
)
|
||||
|
||||
from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache
|
||||
|
||||
# torch 2.8+ on main uses _get_shape_env in AOTAutogradCache
|
||||
if hasattr(AOTAutogradCache, "_get_shape_env"):
|
||||
stack.enter_context(
|
||||
patch(
|
||||
"torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env",
|
||||
_get_shape_env,
|
||||
)
|
||||
)
|
||||
|
||||
# for forcing the graph to be cached
|
||||
stack.enter_context(
|
||||
patch(
|
||||
"torch._inductor.codecache.FxGraphCache._check_can_cache",
|
||||
_check_can_cache,
|
||||
)
|
||||
)
|
||||
|
||||
# Dynamo metrics context, see method for more details.
|
||||
stack.enter_context(self.metrics_context())
|
||||
|
||||
# Disable remote caching. When these are on, on remote cache-hit,
|
||||
# the monkey-patched functions never actually get called.
|
||||
# vLLM today assumes and requires the monkey-patched functions to
|
||||
# get hit.
|
||||
# TODO(zou3519): we're going to replace this all with
|
||||
# standalone_compile sometime.
|
||||
if is_torch_equal_or_newer("2.6"):
|
||||
stack.enter_context(
|
||||
torch._inductor.config.patch(fx_graph_remote_cache=False)
|
||||
)
|
||||
# InductorAdaptor (unfortunately) requires AOTAutogradCache
|
||||
# to be turned off to run. It will fail to acquire the hash_str
|
||||
# and error if not.
|
||||
# StandaloneInductorAdaptor (PyTorch 2.8+) fixes this problem.
|
||||
stack.enter_context(
|
||||
torch._functorch.config.patch(enable_autograd_cache=False)
|
||||
)
|
||||
stack.enter_context(
|
||||
torch._functorch.config.patch(enable_remote_autograd_cache=False)
|
||||
)
|
||||
|
||||
compiled_graph = compile_fx(
|
||||
graph,
|
||||
example_inputs,
|
||||
inner_compile=hijacked_compile_fx_inner,
|
||||
config_patches=current_config,
|
||||
)
|
||||
|
||||
# Turn off the checks if we disable the compilation cache.
|
||||
if is_compile_cache_enabled(compiler_config):
|
||||
if hash_str is None:
|
||||
raise RuntimeError(
|
||||
"vLLM failed to compile the model. The most "
|
||||
"likely reason for this is that a previous compilation "
|
||||
"failed, leading to a corrupted compilation artifact. "
|
||||
"We recommend trying to "
|
||||
"remove ~/.cache/vllm/torch_compile_cache and try again "
|
||||
"to see the real issue. "
|
||||
)
|
||||
assert file_path is not None, (
|
||||
"failed to get the file path of the compiled graph"
|
||||
)
|
||||
return compiled_graph, (hash_str, file_path)
|
||||
|
||||
def load(
|
||||
self,
|
||||
handle: Any,
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: list[Any],
|
||||
graph_index: int,
|
||||
runtime_shape: int | None = None,
|
||||
) -> Callable:
|
||||
assert isinstance(handle, tuple)
|
||||
assert isinstance(handle[0], str)
|
||||
assert isinstance(handle[1], str)
|
||||
hash_str = handle[0]
|
||||
|
||||
from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache
|
||||
from torch._inductor.codecache import FxGraphCache
|
||||
|
||||
with ExitStack() as exit_stack:
|
||||
exit_stack.enter_context(
|
||||
patch(
|
||||
"torch._inductor.codecache.FxGraphCache._get_shape_env",
|
||||
lambda *args, **kwargs: AlwaysHitShapeEnv(),
|
||||
)
|
||||
)
|
||||
# torch 2.8+ on main uses _get_shape_env in AOTAutogradCache
|
||||
if hasattr(AOTAutogradCache, "_get_shape_env"):
|
||||
exit_stack.enter_context(
|
||||
patch(
|
||||
"torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env",
|
||||
lambda *args, **kwargs: AlwaysHitShapeEnv(),
|
||||
)
|
||||
)
|
||||
|
||||
# Dynamo metrics context, see method for more details.
|
||||
exit_stack.enter_context(self.metrics_context())
|
||||
|
||||
if torch.__version__.startswith("2.5"):
|
||||
inductor_compiled_graph = FxGraphCache._lookup_graph(
|
||||
hash_str, example_inputs, True, False
|
||||
)
|
||||
assert inductor_compiled_graph is not None, (
|
||||
"Inductor cache lookup failed. Please remove"
|
||||
f"the cache directory and try again." # noqa
|
||||
)
|
||||
elif torch.__version__ >= "2.6":
|
||||
from torch._inductor.output_code import CompiledFxGraphConstantsWithGm
|
||||
|
||||
constants = CompiledFxGraphConstantsWithGm(graph)
|
||||
inductor_compiled_graph, _ = FxGraphCache._lookup_graph(
|
||||
hash_str, example_inputs, True, None, constants
|
||||
)
|
||||
assert inductor_compiled_graph is not None, (
|
||||
"Inductor cache lookup failed. Please remove"
|
||||
f"the cache directory and try again." # noqa
|
||||
)
|
||||
|
||||
# Inductor calling convention (function signature):
|
||||
# f(list) -> tuple
|
||||
# Dynamo calling convention (function signature):
|
||||
# f(*args) -> Any
|
||||
|
||||
# need to know if the graph returns a tuple
|
||||
from torch._inductor.compile_fx import graph_returns_tuple
|
||||
|
||||
returns_tuple = graph_returns_tuple(graph)
|
||||
|
||||
# this is the callable we return to Dynamo to run
|
||||
def compiled_graph(*args):
|
||||
# convert args to list
|
||||
list_args = list(args)
|
||||
graph_output = inductor_compiled_graph(list_args)
|
||||
# unpack the tuple if needed
|
||||
if returns_tuple:
|
||||
return graph_output
|
||||
else:
|
||||
return graph_output[0]
|
||||
|
||||
return compiled_graph
|
||||
|
||||
def metrics_context(self) -> contextlib.AbstractContextManager:
|
||||
"""
|
||||
This method returns the Dynamo metrics context (if it exists,
|
||||
otherwise a null context). It is used by various compile components.
|
||||
Present in torch>=2.6, it's used inside FxGraphCache in
|
||||
torch==2.6 (but not after). It might also be used in various other
|
||||
torch.compile internal functions.
|
||||
|
||||
Because it is re-entrant, we always set it (even if entering via Dynamo
|
||||
and the context was already entered). We might want to revisit if it
|
||||
should be set at a different mode of compilation.
|
||||
|
||||
This is likely a bug in PyTorch: public APIs should not rely on
|
||||
manually setting up internal contexts. But we also rely on non-public
|
||||
APIs which might not provide these guarantees.
|
||||
"""
|
||||
if is_torch_equal_or_newer("2.6"):
|
||||
import torch._dynamo.utils
|
||||
|
||||
return torch._dynamo.utils.get_metrics_context()
|
||||
else:
|
||||
return contextlib.nullcontext()
|
||||
|
||||
|
||||
def set_inductor_config(config, runtime_shape):
|
||||
if isinstance(runtime_shape, int):
|
||||
# for a specific batchsize, tuning triton kernel parameters
|
||||
# can be beneficial
|
||||
config["max_autotune"] = envs.VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE
|
||||
config["coordinate_descent_tuning"] = (
|
||||
envs.VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING
|
||||
)
|
||||
|
||||
|
||||
def set_functorch_config():
|
||||
torch._functorch.config.bundled_autograd_cache = False
|
||||
|
||||
|
||||
class EagerAdaptor(CompilerInterface):
|
||||
name = "eager"
|
||||
|
||||
def compile(
|
||||
self,
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: list[Any],
|
||||
compiler_config: dict[str, Any],
|
||||
runtime_shape: int | None = None,
|
||||
key: str | None = None,
|
||||
) -> tuple[Callable | None, Any | None]:
|
||||
compilation_counter.num_eager_compiles += 1
|
||||
# we don't need to compile the graph, just return the graph itself.
|
||||
# It does not support caching, return None for the handle.
|
||||
return graph, None
|
||||
48
compilation/counter.py
Normal file
48
compilation/counter.py
Normal file
@@ -0,0 +1,48 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import copy
|
||||
import dataclasses
|
||||
from contextlib import contextmanager
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CompilationCounter:
|
||||
num_models_seen: int = 0
|
||||
num_graphs_seen: int = 0
|
||||
# including the splitting ops
|
||||
num_piecewise_graphs_seen: int = 0
|
||||
# not including the splitting ops
|
||||
num_piecewise_capturable_graphs_seen: int = 0
|
||||
num_backend_compilations: int = 0
|
||||
# Number of gpu_model_runner attempts to trigger CUDAGraphs capture
|
||||
num_gpu_runner_capture_triggers: int = 0
|
||||
# Number of CUDAGraphs captured
|
||||
num_cudagraph_captured: int = 0
|
||||
# InductorAdapter.compile calls
|
||||
num_inductor_compiles: int = 0
|
||||
# EagerAdapter.compile calls
|
||||
num_eager_compiles: int = 0
|
||||
# The number of time vLLM's compiler cache entry was updated
|
||||
num_cache_entries_updated: int = 0
|
||||
# The number of standalone_compile compiled artifacts saved
|
||||
num_compiled_artifacts_saved: int = 0
|
||||
# Number of times a model was loaded with CompilationMode.STOCK_TORCH_COMPILE
|
||||
stock_torch_compile_count: int = 0
|
||||
|
||||
def clone(self) -> "CompilationCounter":
|
||||
return copy.deepcopy(self)
|
||||
|
||||
@contextmanager
|
||||
def expect(self, **kwargs):
|
||||
old = self.clone()
|
||||
yield
|
||||
for k, v in kwargs.items():
|
||||
assert getattr(self, k) - getattr(old, k) == v, (
|
||||
f"{k} not as expected, before it is {getattr(old, k)}"
|
||||
f", after it is {getattr(self, k)}, "
|
||||
f"expected diff is {v}"
|
||||
)
|
||||
|
||||
|
||||
compilation_counter = CompilationCounter()
|
||||
216
compilation/cuda_graph.py
Normal file
216
compilation/cuda_graph.py
Normal file
@@ -0,0 +1,216 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import dataclasses
|
||||
from collections.abc import Callable
|
||||
from contextlib import ExitStack
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.monitor import validate_cudagraph_capturing_enabled
|
||||
from vllm.config import CUDAGraphMode, VllmConfig
|
||||
from vllm.distributed.device_communicators.pynccl_allocator import set_graph_pool_id
|
||||
from vllm.forward_context import BatchDescriptor, get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import weak_ref_tensors
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CUDAGraphEntry:
|
||||
batch_descriptor: BatchDescriptor
|
||||
cudagraph: torch.cuda.CUDAGraph | None = None
|
||||
output: Any | None = None
|
||||
|
||||
# for cudagraph debugging, track the input addresses
|
||||
# during capture, and check if they are the same during replay
|
||||
input_addresses: list[int] | None = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CUDAGraphOptions:
|
||||
debug_log_enable: bool = True
|
||||
gc_disable: bool = False
|
||||
weak_ref_output: bool = True
|
||||
|
||||
|
||||
class CUDAGraphWrapper:
|
||||
"""Wraps a runnable to add CUDA graph capturing and replaying ability. And
|
||||
provide attribute access to the underlying `runnable` via `__getattr__`.
|
||||
|
||||
The workflow of this wrapper in the cudagraph dispatching is as follows:
|
||||
1. At initialization, a runtime mode is assigned to the wrapper (FULL or
|
||||
PIECEWISE).
|
||||
2. At runtime, the wrapper receives a runtime_mode and a
|
||||
batch_descriptor(key) from the forward context and blindly trust them
|
||||
for cudagraph dispatching.
|
||||
3. If runtime_mode is NONE or runtime_mode does not match the mode of the
|
||||
wrapper, just call the runnable directly.
|
||||
4. Otherwise, i.e., the runtime_mode matches the mode of the wrapper,
|
||||
the wrapper will perform cudagraph capture(if key does not exist, create
|
||||
a new entry and cache it) or replay (if key exists in the cache).
|
||||
|
||||
Note: CUDAGraphWrapper does not store persistent buffers or copy any
|
||||
runtime inputs into that buffers for replay. We assume implementing them
|
||||
is done outside of the wrapper. That is because we do not make any
|
||||
assumption on the dynamic shape (batch size) of the runtime inputs, as a
|
||||
trade-off for staying orthogonal to compilation logic. Nevertheless,
|
||||
tracing and checking the input addresses to be consistent during replay is
|
||||
guaranteed when VLLM_LOGGING_LEVEL == "DEBUG".
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
runnable: Callable,
|
||||
vllm_config: VllmConfig,
|
||||
runtime_mode: CUDAGraphMode,
|
||||
cudagraph_options: CUDAGraphOptions | None = None,
|
||||
):
|
||||
self.runnable = runnable
|
||||
self.vllm_config = vllm_config
|
||||
self.runtime_mode = runtime_mode
|
||||
self.compilation_config = vllm_config.compilation_config
|
||||
|
||||
self.first_run_finished = False
|
||||
self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"
|
||||
|
||||
# assert runtime_mode is not NONE(no cudagraph), otherwise, we don't
|
||||
# need to initialize a CUDAGraphWrapper.
|
||||
assert self.runtime_mode != CUDAGraphMode.NONE
|
||||
# TODO: in the future, if we want to use multiple
|
||||
# streams, it might not be safe to share a global pool.
|
||||
# only investigate this when we use multiple streams
|
||||
self.graph_pool = current_platform.get_global_graph_pool()
|
||||
|
||||
if cudagraph_options is None:
|
||||
cudagraph_options = CUDAGraphOptions()
|
||||
self.cudagraph_options = cudagraph_options
|
||||
# the entries for different batch descriptors that we need to capture
|
||||
# cudagraphs for.
|
||||
self.concrete_cudagraph_entries: dict[BatchDescriptor, CUDAGraphEntry] = {}
|
||||
|
||||
def __getattr__(self, key: str):
|
||||
# allow accessing the attributes of the runnable.
|
||||
if hasattr(self.runnable, key):
|
||||
return getattr(self.runnable, key)
|
||||
raise AttributeError(
|
||||
f"Attribute {key} not exists in the runnable of "
|
||||
f"cudagraph wrapper: {self.runnable}"
|
||||
)
|
||||
|
||||
def unwrap(self) -> Callable:
|
||||
# in case we need to access the original runnable.
|
||||
return self.runnable
|
||||
|
||||
def weak_ref_tensors_with_intermediate(self, output):
|
||||
if isinstance(output, IntermediateTensors):
|
||||
intermediate_states = IntermediateTensors(
|
||||
tensors={key: weak_ref_tensors(value) for key, value in output.tensors.items()})
|
||||
return intermediate_states
|
||||
return weak_ref_tensors(output)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
forward_context = get_forward_context()
|
||||
batch_descriptor = forward_context.batch_descriptor
|
||||
cudagraph_runtime_mode = forward_context.cudagraph_runtime_mode
|
||||
|
||||
if (
|
||||
cudagraph_runtime_mode == CUDAGraphMode.NONE
|
||||
or cudagraph_runtime_mode != self.runtime_mode
|
||||
):
|
||||
# CUDAGraphMode.NONE could mean the profile run, a warmup run, or
|
||||
# running without cudagraphs.
|
||||
# We do not trigger capture/replay if the runtime mode is not
|
||||
# matches. This enables properly dispatching to the correct
|
||||
# CUDAGraphWrapper when nesting multiple instances with different
|
||||
# runtime modes.
|
||||
return self.runnable(*args, **kwargs)
|
||||
|
||||
if batch_descriptor not in self.concrete_cudagraph_entries:
|
||||
# create a new entry for this batch descriptor
|
||||
self.concrete_cudagraph_entries[batch_descriptor] = CUDAGraphEntry(
|
||||
batch_descriptor=batch_descriptor
|
||||
)
|
||||
|
||||
entry = self.concrete_cudagraph_entries[batch_descriptor]
|
||||
|
||||
if entry.cudagraph is None:
|
||||
if self.cudagraph_options.debug_log_enable:
|
||||
# Since we capture cudagraph for many different shapes and
|
||||
# capturing is fast, we don't need to log it for every
|
||||
# shape. E.g. we only log it for the first subgraph in
|
||||
# piecewise mode.
|
||||
logger.debug(
|
||||
"Capturing a cudagraph on (%s,%s)",
|
||||
self.runtime_mode.name,
|
||||
entry.batch_descriptor,
|
||||
)
|
||||
# validate that cudagraph capturing is legal at this point.
|
||||
validate_cudagraph_capturing_enabled()
|
||||
|
||||
input_addresses = [
|
||||
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
|
||||
]
|
||||
entry.input_addresses = input_addresses
|
||||
cudagraph = torch.cuda.CUDAGraph()
|
||||
|
||||
with ExitStack() as stack:
|
||||
if self.cudagraph_options.gc_disable:
|
||||
# during every model forward for piecewise cudagraph
|
||||
# mode, we will capture many pieces of cudagraphs
|
||||
# (roughly one per layer). running gc again and again
|
||||
# across layers will make the cudagraph capture very slow.
|
||||
# therefore, we only run gc for the first graph,
|
||||
# and disable gc for the rest of the graphs.
|
||||
stack.enter_context(patch("gc.collect", lambda: None))
|
||||
stack.enter_context(patch("torch.cuda.empty_cache", lambda: None))
|
||||
|
||||
if self.graph_pool is not None:
|
||||
set_graph_pool_id(self.graph_pool)
|
||||
else:
|
||||
set_graph_pool_id(current_platform.graph_pool_handle())
|
||||
# mind-exploding: carefully manage the reference and memory.
|
||||
with torch.cuda.graph(cudagraph, pool=self.graph_pool):
|
||||
# `output` is managed by pytorch's cudagraph pool
|
||||
output = self.runnable(*args, **kwargs)
|
||||
if self.cudagraph_options.weak_ref_output:
|
||||
# by converting it to weak ref,
|
||||
# the original `output` will immediately be released
|
||||
# to save memory. It is only safe to do this for
|
||||
# the last graph in piecewise cuadgraph mode, because
|
||||
# the output of the last graph will not be used by
|
||||
# any other cuda graph.
|
||||
output = self.weak_ref_tensors_with_intermediate(output)
|
||||
|
||||
# here we always use weak ref for the output
|
||||
# to save memory
|
||||
entry.output = self.weak_ref_tensors_with_intermediate(output)
|
||||
entry.cudagraph = cudagraph
|
||||
|
||||
compilation_counter.num_cudagraph_captured += 1
|
||||
|
||||
# important: we need to return the output, rather than
|
||||
# the weak ref of the output, so that pytorch can correctly
|
||||
# manage the memory during cuda graph capture
|
||||
return output
|
||||
|
||||
if self.is_debugging_mode:
|
||||
# check if the input addresses are the same
|
||||
new_input_addresses = [
|
||||
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
|
||||
]
|
||||
assert new_input_addresses == entry.input_addresses, (
|
||||
f"Input addresses for cudagraphs are different "
|
||||
f"during replay. Expected {entry.input_addresses}, "
|
||||
f"got {new_input_addresses}"
|
||||
)
|
||||
|
||||
entry.cudagraph.replay()
|
||||
return entry.output
|
||||
571
compilation/decorators.py
Normal file
571
compilation/decorators.py
Normal file
@@ -0,0 +1,571 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import contextlib
|
||||
import hashlib
|
||||
import inspect
|
||||
import os
|
||||
import sys
|
||||
from collections.abc import Callable
|
||||
from typing import TypeVar, overload
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from packaging import version
|
||||
from torch._dynamo.symbolic_convert import InliningInstructionTranslator
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.wrapper import TorchCompileWithNoGuardsWrapper
|
||||
from vllm.config import (
|
||||
CompilationMode,
|
||||
VllmConfig,
|
||||
get_current_vllm_config,
|
||||
set_current_vllm_config,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils.import_utils import resolve_obj_by_qualname
|
||||
from vllm.utils.torch_utils import supports_dynamo
|
||||
|
||||
from .monitor import start_monitoring_torch_compile
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
IGNORE_COMPILE_KEY = "_ignore_compile_vllm"
|
||||
|
||||
_T = TypeVar("_T", bound=type[nn.Module])
|
||||
|
||||
|
||||
def ignore_torch_compile(cls: _T) -> _T:
|
||||
"""
|
||||
A decorator to ignore support_torch_compile decorator
|
||||
on the class. This is useful when a parent class has
|
||||
a support_torch_compile decorator, but we don't want to
|
||||
compile the class `cls` that inherits the parent class.
|
||||
This only ignores compiling the forward of the class the
|
||||
decorator is applied to.
|
||||
|
||||
If the parent has ignore_torch_compile but the child has
|
||||
support_torch_compile, the child will still be compiled.
|
||||
|
||||
If the class has one or more submodules
|
||||
that have support_torch_compile decorator applied, compile will
|
||||
not be ignored for those submodules.
|
||||
"""
|
||||
setattr(cls, IGNORE_COMPILE_KEY, True)
|
||||
return cls
|
||||
|
||||
|
||||
def _should_ignore_torch_compile(cls) -> bool:
|
||||
"""
|
||||
Check if the class should be ignored for torch.compile.
|
||||
"""
|
||||
return getattr(cls, IGNORE_COMPILE_KEY, False)
|
||||
|
||||
|
||||
@overload
|
||||
def support_torch_compile(
|
||||
*,
|
||||
enable_if: Callable[[VllmConfig], bool] | None = None,
|
||||
) -> Callable[[_T], _T]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def support_torch_compile(
|
||||
*,
|
||||
dynamic_arg_dims: dict[str, int | list[int]] | None,
|
||||
) -> Callable[[_T], _T]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def support_torch_compile(
|
||||
*,
|
||||
mark_unbacked_dims: dict[str, int | list[int]] | None,
|
||||
) -> Callable[[_T], _T]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def support_torch_compile(
|
||||
*,
|
||||
dynamic_arg_dims: dict[str, int | list[int]] | None,
|
||||
mark_unbacked_dims: dict[str, int | list[int]] | None,
|
||||
) -> Callable[[_T], _T]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def support_torch_compile(cls: _T) -> _T: ...
|
||||
|
||||
|
||||
def support_torch_compile(
|
||||
cls: _T | None = None,
|
||||
*,
|
||||
dynamic_arg_dims: dict[str, int | list[int]] | None = None,
|
||||
mark_unbacked_dims: dict[str, int | list[int]] | None = None,
|
||||
enable_if: Callable[[VllmConfig], bool] | None = None,
|
||||
) -> Callable[[_T], _T] | _T:
|
||||
"""
|
||||
A decorator to add support for compiling the forward method of a class.
|
||||
|
||||
Usage 1: use directly as a decorator without arguments:
|
||||
|
||||
```python
|
||||
@support_torch_compile
|
||||
class MyModel(nn.Module):
|
||||
def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): ...
|
||||
```
|
||||
|
||||
Usage 2: use as a decorator with arguments:
|
||||
|
||||
```python
|
||||
@support_torch_compile(dynamic_arg_dims={"x": 0, "y": 0})
|
||||
class MyModel(nn.Module):
|
||||
def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): ...
|
||||
```
|
||||
|
||||
`dynamic_arg_dims` is a dictionary that maps argument names to the dynamic
|
||||
dimensions of the argument. The dynamic dimensions can be either a single
|
||||
integer or a list of integers.
|
||||
|
||||
if `dynamic_arg_dims` is `None`, it is inferred from the type annotation
|
||||
of the `forward` method, based on the following default rules:
|
||||
|
||||
- if the argument is annotated as `torch.Tensor` or
|
||||
`Optional[torch.Tensor]`, the first dimension will be
|
||||
marked as dynamic.
|
||||
- if the argument is annotated as `IntermediateTensors`, the first
|
||||
dimension of all the tensors in the intermediate tensors
|
||||
will be marked as dynamic.
|
||||
|
||||
During runtime, when we actually mark dimensions of tensors,
|
||||
it depends on the value of arguments:
|
||||
|
||||
- if it is a single integer (can be negative), the corresponding dimension
|
||||
of the argument will be marked as dynamic.
|
||||
- if it is `None`, ignored.
|
||||
- if it is `IntermediateTensors`, all the tensors in the intermediate
|
||||
tensors will be marked as dynamic.
|
||||
- otherwise, it will raise an error.
|
||||
|
||||
NOTE: if an argument is `None`, it should always be passed as `None` during
|
||||
the lifetime of the model, otherwise, it cannot be captured as a single
|
||||
computation graph.
|
||||
|
||||
`enable_if` is a function that takes a `VllmConfig` object as input and
|
||||
returns a boolean value indicating whether to compile the model or not.
|
||||
This is useful if you want to compile the model only when certain
|
||||
conditions are met.
|
||||
|
||||
`mark_unbacked_dims` is a dictionary that maps argument names with a dynamic
|
||||
dim to be decorated with `mark_unbacked`. This is useful if we would like to
|
||||
enforce that dynamo does not specialize on 0/1 values in the case of dummy input
|
||||
such as for vision model compilation
|
||||
"""
|
||||
|
||||
def cls_decorator_helper(cls: _T) -> _T:
|
||||
# helper to pass `dynamic_arg_dims` to `_support_torch_compile`
|
||||
# to avoid too much indentation for `_support_torch_compile`
|
||||
if not hasattr(cls, "forward"):
|
||||
raise TypeError("decorated class should have a forward method.")
|
||||
sig = inspect.signature(cls.forward)
|
||||
inferred_dynamic_arg_dims = dynamic_arg_dims
|
||||
if inferred_dynamic_arg_dims is None:
|
||||
inferred_dynamic_arg_dims = {}
|
||||
for k, v in sig.parameters.items():
|
||||
if v.annotation in [
|
||||
torch.Tensor,
|
||||
torch.Tensor | None,
|
||||
IntermediateTensors,
|
||||
IntermediateTensors | None,
|
||||
]:
|
||||
inferred_dynamic_arg_dims[k] = 0
|
||||
|
||||
logger.debug(
|
||||
("Inferred dynamic dimensions for forward method of %s: %s"),
|
||||
cls,
|
||||
list(inferred_dynamic_arg_dims.keys()),
|
||||
)
|
||||
|
||||
if len(inferred_dynamic_arg_dims) == 0:
|
||||
raise ValueError(
|
||||
"No dynamic dimensions found in the forward method of "
|
||||
f"{cls}. Please provide dynamic_arg_dims explicitly."
|
||||
)
|
||||
|
||||
for k in inferred_dynamic_arg_dims:
|
||||
if k not in sig.parameters:
|
||||
raise ValueError(
|
||||
f"Argument {k} not found in the forward method of {cls}"
|
||||
)
|
||||
return _support_torch_compile(
|
||||
cls, inferred_dynamic_arg_dims, mark_unbacked_dims, enable_if
|
||||
)
|
||||
|
||||
if cls is not None:
|
||||
# use `support_torch_compile` as a decorator without arguments
|
||||
assert isinstance(cls, type)
|
||||
return cls_decorator_helper(cls)
|
||||
|
||||
return cls_decorator_helper
|
||||
|
||||
|
||||
def _model_hash_key(fn) -> str:
|
||||
import vllm
|
||||
|
||||
sha256_hash = hashlib.sha256()
|
||||
sha256_hash.update(vllm.__version__.encode())
|
||||
sha256_hash.update(fn.__qualname__.encode())
|
||||
sha256_hash.update(str(fn.__code__.co_firstlineno).encode())
|
||||
return sha256_hash.hexdigest()
|
||||
|
||||
|
||||
def _verify_source_unchanged(source_info, vllm_config) -> None:
|
||||
from .caching import _compute_code_hash, _compute_code_hash_with_content
|
||||
|
||||
file_contents = {}
|
||||
for source in source_info.inlined_sources:
|
||||
module = sys.modules[source.module]
|
||||
file = inspect.getfile(module)
|
||||
vllm_config.compilation_config.traced_files.add(file)
|
||||
file_contents[file] = source.content
|
||||
expected_checksum = _compute_code_hash_with_content(file_contents)
|
||||
actual_checksum = _compute_code_hash(set(file_contents.keys()))
|
||||
if expected_checksum != actual_checksum:
|
||||
raise RuntimeError(
|
||||
"Source code has changed since the last compilation. Recompiling the model."
|
||||
)
|
||||
|
||||
|
||||
def _support_torch_compile(
|
||||
cls: _T,
|
||||
dynamic_arg_dims: dict[str, int | list[int]],
|
||||
mark_unbacked_dims: dict[str, int | list[int]] | None = None,
|
||||
enable_if: Callable[[VllmConfig], bool] | None = None,
|
||||
) -> _T:
|
||||
"""
|
||||
A decorator to add support for compiling the forward method of a class.
|
||||
"""
|
||||
if TorchCompileWithNoGuardsWrapper in cls.__bases__:
|
||||
# support decorating multiple times
|
||||
return cls
|
||||
|
||||
# take care of method resolution order
|
||||
# make sure super().__init__ is called on the base class
|
||||
# other than TorchCompileWithNoGuardsWrapper
|
||||
cls.__bases__ = cls.__bases__ + (TorchCompileWithNoGuardsWrapper,)
|
||||
|
||||
old_init = cls.__init__
|
||||
|
||||
setattr(cls, IGNORE_COMPILE_KEY, False)
|
||||
|
||||
def __init__(
|
||||
self, *, vllm_config: VllmConfig | None = None, prefix: str = "", **kwargs
|
||||
):
|
||||
if vllm_config is None:
|
||||
vllm_config = get_current_vllm_config()
|
||||
|
||||
# NOTE: to support multimodal models (such as encoder),
|
||||
# we may not have vllm_config so we may need to patch
|
||||
# it
|
||||
sig = inspect.signature(old_init)
|
||||
if "vllm_config" in sig.parameters:
|
||||
kwargs["vllm_config"] = vllm_config
|
||||
if "prefix" in sig.parameters:
|
||||
kwargs["prefix"] = prefix
|
||||
old_init(self, **kwargs)
|
||||
|
||||
self.vllm_config = vllm_config
|
||||
enable_compile = enable_if is None or enable_if(vllm_config)
|
||||
# for CompilationMode.STOCK_TORCH_COMPILE , the upper level model runner
|
||||
# will handle the compilation, so we don't need to do anything here.
|
||||
self.do_not_compile = (
|
||||
vllm_config.compilation_config.mode
|
||||
in [CompilationMode.NONE, CompilationMode.STOCK_TORCH_COMPILE]
|
||||
or not supports_dynamo()
|
||||
or _should_ignore_torch_compile(self.__class__)
|
||||
or not enable_compile
|
||||
)
|
||||
if self.do_not_compile:
|
||||
return
|
||||
|
||||
compilation_counter.num_models_seen += 1
|
||||
self.compiled = False
|
||||
TorchCompileWithNoGuardsWrapper.__init__(self)
|
||||
|
||||
cls.__init__ = __init__
|
||||
|
||||
def _mark_dynamic_inputs(mod, *args, **kwargs):
|
||||
sig = inspect.signature(mod.__class__.forward)
|
||||
bound_args = sig.bind(mod, *args, **kwargs)
|
||||
bound_args.apply_defaults()
|
||||
for k, dims in dynamic_arg_dims.items():
|
||||
arg = bound_args.arguments.get(k)
|
||||
if arg is not None:
|
||||
dims = [dims] if isinstance(dims, int) else dims
|
||||
if isinstance(arg, torch.Tensor):
|
||||
# In case dims is specified with negative indexing
|
||||
dims = [arg.ndim + dim if dim < 0 else dim for dim in dims]
|
||||
torch._dynamo.mark_dynamic(arg, dims)
|
||||
elif isinstance(arg, IntermediateTensors):
|
||||
for tensor in arg.tensors.values():
|
||||
# In case dims is specified with negative indexing
|
||||
dims = [tensor.ndim + dim if dim < 0 else dim for dim in dims]
|
||||
torch._dynamo.mark_dynamic(tensor, dims)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unsupported dynamic dimensions"
|
||||
f" {dims} for argument {k} with type {type(arg)}."
|
||||
)
|
||||
if mark_unbacked_dims:
|
||||
for k, dims in mark_unbacked_dims.items():
|
||||
arg = bound_args.arguments.get(k)
|
||||
if arg is not None:
|
||||
dims = [dims] if isinstance(dims, int) else dims
|
||||
if isinstance(arg, torch.Tensor):
|
||||
# In case dims is specified with negative indexing
|
||||
dims = [arg.ndim + dim if dim < 0 else dim for dim in dims]
|
||||
torch._dynamo.decorators.mark_unbacked(arg, dims)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
# torch.compiler.is_compiling() means we are inside the compilation
|
||||
# e.g. TPU has the compilation logic in model runner, so we don't
|
||||
# need to compile the model inside.
|
||||
if self.do_not_compile or torch.compiler.is_compiling():
|
||||
return self.forward(*args, **kwargs)
|
||||
|
||||
# if aot_compiled_fn is set, just call it.
|
||||
if getattr(self, "aot_compiled_fn", None) is not None:
|
||||
return self.aot_compiled_fn(self, *args, **kwargs)
|
||||
|
||||
cache_dir = None
|
||||
aot_compilation_path = None
|
||||
if envs.VLLM_USE_AOT_COMPILE:
|
||||
"""
|
||||
When using torch.compile in AOT mode, we store the cache artifacts
|
||||
under VLLM_CACHE_ROOT/torch_aot_compile/{hash}/rank_i_j. The {hash}
|
||||
contains all of the factors except for the source files being
|
||||
traced through, because we don't actually know which source files
|
||||
to check at this point (before dynamo runs).
|
||||
On loading we will actually look at the source files being traced
|
||||
through. If any source file have changed (compared with the
|
||||
serialized backend artifacts), then we need to generate a new AOT
|
||||
compile artifact from scratch.
|
||||
"""
|
||||
from .caching import compilation_config_hash_factors
|
||||
|
||||
factors: list[str] = compilation_config_hash_factors(self.vllm_config)
|
||||
|
||||
factors.append(_model_hash_key(self.forward))
|
||||
hash_key = hashlib.sha256(str(factors).encode()).hexdigest()
|
||||
|
||||
cache_dir = os.path.join(
|
||||
envs.VLLM_CACHE_ROOT,
|
||||
"torch_aot_compile",
|
||||
hash_key,
|
||||
)
|
||||
|
||||
rank = self.vllm_config.parallel_config.rank
|
||||
dp_rank = self.vllm_config.parallel_config.data_parallel_rank
|
||||
cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}")
|
||||
aot_compilation_path = os.path.join(cache_dir, "model")
|
||||
try:
|
||||
with (
|
||||
set_current_vllm_config(self.vllm_config),
|
||||
open(aot_compilation_path, "rb") as f,
|
||||
):
|
||||
start_monitoring_torch_compile(self.vllm_config)
|
||||
loaded_fn = torch.compiler.load_compiled_function(f)
|
||||
_verify_source_unchanged(loaded_fn.source_info(), self.vllm_config)
|
||||
loaded_fn.disable_guard_check()
|
||||
self.aot_compiled_fn = loaded_fn
|
||||
except Exception as e:
|
||||
if os.path.exists(aot_compilation_path):
|
||||
logger.warning(
|
||||
"Cannot load aot compilation from path %s, error: %s",
|
||||
aot_compilation_path,
|
||||
str(e),
|
||||
)
|
||||
if envs.VLLM_FORCE_AOT_LOAD:
|
||||
raise e
|
||||
if getattr(self, "aot_compiled_fn", None) is not None:
|
||||
logger.info(
|
||||
"Directly load AOT compilation from path %s", aot_compilation_path
|
||||
)
|
||||
return self.aot_compiled_fn(self, *args, **kwargs)
|
||||
|
||||
if self.compiled:
|
||||
assert not envs.VLLM_USE_AOT_COMPILE
|
||||
return TorchCompileWithNoGuardsWrapper.__call__(self, *args, **kwargs)
|
||||
|
||||
# This is the path for the first compilation.
|
||||
|
||||
# the first compilation needs to have dynamic shapes marked
|
||||
_mark_dynamic_inputs(self, *args, **kwargs)
|
||||
|
||||
# here, it is the starting point of the `torch.compile` process
|
||||
start_monitoring_torch_compile(self.vllm_config)
|
||||
original_code_object = self.original_code_object()
|
||||
logger.debug("Start compiling function %s", original_code_object)
|
||||
|
||||
# we do not want tp delete the original code object entries since
|
||||
# we depend on them now to look up cached compiled functions.
|
||||
# torch._dynamo.eval_frame.remove_from_cache(original_code_object)
|
||||
|
||||
# collect all relevant files traced by Dynamo,
|
||||
# so that the compilation cache can trigger re-compilation
|
||||
# properly when any of these files change.
|
||||
|
||||
# 1. the file containing the top-level forward function
|
||||
self.vllm_config.compilation_config.traced_files.add(
|
||||
original_code_object.co_filename
|
||||
)
|
||||
|
||||
# 2. every time Dynamo sees a function call, it will inline
|
||||
# the function by calling InliningInstructionTranslator.inline_call_
|
||||
# we hijack this function to know all the functions called
|
||||
# during Dynamo tracing, and their corresponding files
|
||||
inline_call = InliningInstructionTranslator.inline_call_
|
||||
|
||||
def patched_inline_call(self_):
|
||||
code = self_.f_code
|
||||
self.vllm_config.compilation_config.traced_files.add(code.co_filename)
|
||||
return inline_call(self_)
|
||||
|
||||
# Disable the C++ compilation of symbolic shape guards. C++-fication
|
||||
# of symbolic shape guards can improve guard overhead. But, since
|
||||
# vllm skip guards anyways, setting this flag to False can improve
|
||||
# compile time.
|
||||
dynamo_config_patches = {}
|
||||
try:
|
||||
_ = torch._dynamo.config.enable_cpp_symbolic_shape_guards
|
||||
dynamo_config_patches["enable_cpp_symbolic_shape_guards"] = False
|
||||
except AttributeError:
|
||||
# Note: this config is not available in torch 2.6, we can skip
|
||||
# if the config doesn't exist
|
||||
logger.debug("enable_cpp_symbolic_shape_guards config not available")
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
InliningInstructionTranslator, "inline_call_", patched_inline_call
|
||||
),
|
||||
torch._dynamo.config.patch(**dynamo_config_patches),
|
||||
maybe_use_cudagraph_partition_wrapper(self.vllm_config),
|
||||
_torch27_patch_tensor_subclasses(),
|
||||
):
|
||||
if envs.VLLM_USE_AOT_COMPILE:
|
||||
self.aot_compiled_fn = self.aot_compile(*args, **kwargs)
|
||||
output = self.aot_compiled_fn(self, *args, **kwargs)
|
||||
assert aot_compilation_path is not None
|
||||
assert cache_dir is not None
|
||||
try:
|
||||
os.makedirs(cache_dir, exist_ok=True)
|
||||
self.aot_compiled_fn.save_compiled_function(aot_compilation_path)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Cannot save aot compilation to path %s, error: %s",
|
||||
aot_compilation_path,
|
||||
str(e),
|
||||
)
|
||||
else:
|
||||
output = TorchCompileWithNoGuardsWrapper.__call__(self, *args, **kwargs)
|
||||
|
||||
self.compiled = True
|
||||
return output
|
||||
|
||||
cls.__call__ = __call__
|
||||
return cls
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def maybe_use_cudagraph_partition_wrapper(vllm_config: VllmConfig):
|
||||
"""
|
||||
Context manager to set/unset customized cudagraph partition wrappers.
|
||||
|
||||
If we're using Inductor-based graph partitioning, we currently have the
|
||||
whole `fx.Graph` before Inductor lowering and the piecewise
|
||||
splitting happens after all graph passes and fusions. Here, we add
|
||||
a custom hook for Inductor to wrap each partition with our static
|
||||
graph wrapper class to maintain more control over static graph
|
||||
capture and replay.
|
||||
"""
|
||||
from vllm.config import CUDAGraphMode
|
||||
|
||||
compilation_config = vllm_config.compilation_config
|
||||
if (
|
||||
compilation_config.cudagraph_mode.has_piecewise_cudagraphs()
|
||||
and compilation_config.use_inductor_graph_partition
|
||||
):
|
||||
from torch._inductor.utils import CUDAGraphWrapperMetadata
|
||||
|
||||
from vllm.compilation.cuda_graph import CUDAGraphOptions
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
static_graph_wrapper_class = resolve_obj_by_qualname(
|
||||
current_platform.get_static_graph_wrapper_cls()
|
||||
)
|
||||
|
||||
def customized_cudagraph_wrapper(f, metadata: CUDAGraphWrapperMetadata):
|
||||
partition_id = metadata.partition_index
|
||||
num_partitions = metadata.num_partitions
|
||||
return static_graph_wrapper_class(
|
||||
runnable=f,
|
||||
vllm_config=vllm_config,
|
||||
runtime_mode=CUDAGraphMode.PIECEWISE,
|
||||
cudagraph_options=CUDAGraphOptions(
|
||||
debug_log_enable=partition_id == 0,
|
||||
gc_disable=partition_id != 0,
|
||||
weak_ref_output=partition_id == num_partitions - 1,
|
||||
),
|
||||
)
|
||||
|
||||
torch._inductor.utils.set_customized_partition_wrappers(
|
||||
customized_cudagraph_wrapper
|
||||
)
|
||||
|
||||
yield
|
||||
|
||||
if (
|
||||
compilation_config.cudagraph_mode.has_piecewise_cudagraphs()
|
||||
and compilation_config.use_inductor_graph_partition
|
||||
):
|
||||
torch._inductor.utils.set_customized_partition_wrappers(None)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _torch27_patch_tensor_subclasses():
|
||||
"""
|
||||
Add support for using tensor subclasses (ie `BasevLLMParameter`, ect) when
|
||||
using torch 2.7.0. This enables using weight_loader_v2 and the use of
|
||||
`BasevLLMParameters` without having to replace them with regular tensors
|
||||
before `torch.compile`-time.
|
||||
"""
|
||||
from vllm.model_executor.parameter import (
|
||||
BasevLLMParameter,
|
||||
ModelWeightParameter,
|
||||
RowvLLMParameter,
|
||||
_ColumnvLLMParameter,
|
||||
)
|
||||
|
||||
def return_false(*args, **kwargs):
|
||||
return False
|
||||
|
||||
if version.parse("2.7") <= version.parse(torch.__version__) < version.parse("2.8"):
|
||||
yield
|
||||
return
|
||||
|
||||
with (
|
||||
torch._dynamo.config.patch(
|
||||
"traceable_tensor_subclasses",
|
||||
[
|
||||
BasevLLMParameter,
|
||||
ModelWeightParameter,
|
||||
_ColumnvLLMParameter,
|
||||
RowvLLMParameter,
|
||||
],
|
||||
),
|
||||
patch(
|
||||
"torch._dynamo.variables.torch.can_dispatch_torch_function", return_false
|
||||
),
|
||||
):
|
||||
yield
|
||||
253
compilation/fix_functionalization.py
Normal file
253
compilation/fix_functionalization.py
Normal file
@@ -0,0 +1,253 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import operator
|
||||
from collections.abc import Iterable
|
||||
|
||||
import torch
|
||||
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .fx_utils import is_func
|
||||
from .vllm_inductor_pass import VllmInductorPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class FixFunctionalizationPass(VllmInductorPass):
|
||||
"""
|
||||
This pass defunctionalizes certain nodes to avoid redundant tensor copies.
|
||||
After this pass, DCE (dead-code elimination) should never be run,
|
||||
as de-functionalized nodes may appear as dead code.
|
||||
|
||||
To add new nodes to defunctionalize, add to the if-elif chain in __call__.
|
||||
"""
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: torch.fx.Graph):
|
||||
# XPU does not support auto-functionalization yet.
|
||||
# Will enable this when switch to vllm-xpu-kernels.
|
||||
if current_platform.is_xpu():
|
||||
logger.debug(
|
||||
"XPU platform does not support fix functionalizationpass currently."
|
||||
)
|
||||
return
|
||||
|
||||
self.nodes_to_remove: list[torch.fx.Node] = []
|
||||
count = 0
|
||||
for node in graph.nodes:
|
||||
if not is_func(node, auto_functionalized):
|
||||
continue # Avoid deep if-elif nesting
|
||||
|
||||
kwargs = node.kwargs
|
||||
at_target = node.args[0]
|
||||
|
||||
if at_target == torch.ops._C.rotary_embedding.default:
|
||||
query = kwargs["query"]
|
||||
key = kwargs["key"]
|
||||
getitem_nodes = self.getitem_users(node)
|
||||
|
||||
if (
|
||||
is_func(query, operator.getitem)
|
||||
and is_func(key, operator.getitem)
|
||||
and query.args[0] == key.args[0]
|
||||
and is_func(query.args[0], torch.ops.aten.split_with_sizes.default)
|
||||
and all(
|
||||
is_func(user, torch.ops.aten.slice_scatter.default)
|
||||
for getitem_node in getitem_nodes.values()
|
||||
for user in getitem_node.users
|
||||
)
|
||||
):
|
||||
# Pattern where query and key are slices of an mm_node.
|
||||
# While functionalized, results at [1] and [2] are scattered
|
||||
# back into mm_node. So after de-functionalization, we can
|
||||
# just use mm_node directly.
|
||||
|
||||
mm_node = query.args[0].args[0]
|
||||
for user in getitem_nodes.values():
|
||||
for user_of_getitem in user.users:
|
||||
if is_func(
|
||||
user_of_getitem, torch.ops.aten.slice_scatter.default
|
||||
):
|
||||
user_of_getitem.replace_all_uses_with(mm_node)
|
||||
self._remove(user_of_getitem)
|
||||
self._remove(user)
|
||||
|
||||
self.insert_defunctionalized(graph, node)
|
||||
self._remove(node)
|
||||
|
||||
else:
|
||||
# Directly replace the auto_functionalize(rotary_embedding)
|
||||
# with the inplace rotary_embedding. In theory, we shouldn't
|
||||
# do this blindly, but in practice in vLLM it's ok. The best
|
||||
# solution is to use auto_functionalization_v2 and then use
|
||||
# inductor's builtin defunctionalization (reinplacing) pass.
|
||||
mutated_args = {1: "query", 2: "key"}
|
||||
self.defunctionalize(graph, node, mutated_args)
|
||||
|
||||
# rms_norm replacements avoid the most copies for LLaMa.
|
||||
elif at_target == torch.ops._C.fused_add_rms_norm.default:
|
||||
mutated_args = {1: "input", 2: "residual"}
|
||||
self.defunctionalize(graph, node, mutated_args)
|
||||
elif at_target == torch.ops._C.fused_add_rms_norm_static_fp8_quant.default: # noqa: E501
|
||||
mutated_args = {1: "result", 2: "residual"}
|
||||
self.defunctionalize(graph, node, mutated_args)
|
||||
elif at_target == torch.ops._C.rms_norm_dynamic_per_token_quant.default: # noqa: E501
|
||||
mutated_args = {1: "result", 2: "scale", 3: "residual"}
|
||||
self.defunctionalize(graph, node, mutated_args)
|
||||
elif at_target in [
|
||||
torch.ops._C.rms_norm.default,
|
||||
torch.ops._C.rms_norm_static_fp8_quant.default,
|
||||
]:
|
||||
mutated_args = {1: "result"}
|
||||
self.defunctionalize(graph, node, mutated_args)
|
||||
# For some reason we need to specify the args for both
|
||||
# silu_and_mul and silu_and_mul_quant. The kwargs
|
||||
# pathway gets the wrong answer.
|
||||
elif at_target == torch.ops._C.silu_and_mul.default:
|
||||
mutated_args = {1: "result"}
|
||||
self.defunctionalize(
|
||||
graph, node, mutated_args, args=("result", "input")
|
||||
)
|
||||
elif at_target == torch.ops._C.silu_and_mul_quant.default:
|
||||
mutated_args = {1: "result"}
|
||||
self.defunctionalize(
|
||||
graph, node, mutated_args, args=("result", "input", "scale")
|
||||
)
|
||||
elif (
|
||||
hasattr(torch.ops._C, "silu_and_mul_nvfp4_quant")
|
||||
and at_target == torch.ops._C.silu_and_mul_nvfp4_quant.default
|
||||
):
|
||||
mutated_args = {1: "result", 2: "result_block_scale"}
|
||||
self.defunctionalize(
|
||||
graph,
|
||||
node,
|
||||
mutated_args,
|
||||
args=(
|
||||
"result",
|
||||
"result_block_scale",
|
||||
"input",
|
||||
"input_global_scale",
|
||||
),
|
||||
)
|
||||
# Defunctionalize fused_qk_norm_rope to remove higher-order wrapper.
|
||||
elif at_target == torch.ops._C.fused_qk_norm_rope.default:
|
||||
mutated_args = {1: "qkv"}
|
||||
args = (
|
||||
"qkv",
|
||||
"num_heads_q",
|
||||
"num_heads_k",
|
||||
"num_heads_v",
|
||||
"head_dim",
|
||||
"eps",
|
||||
"q_weight",
|
||||
"k_weight",
|
||||
"cos_sin_cache",
|
||||
"is_neox",
|
||||
"position_ids",
|
||||
)
|
||||
self.defunctionalize(graph, node, mutated_args=mutated_args, args=args)
|
||||
else:
|
||||
continue # skip the count
|
||||
|
||||
count += 1
|
||||
|
||||
self.dump_graph(graph, "before_cleanup")
|
||||
|
||||
# Remove the nodes all at once
|
||||
count_removed = len(self.nodes_to_remove)
|
||||
for node in self.nodes_to_remove:
|
||||
graph.erase_node(node)
|
||||
|
||||
logger.debug(
|
||||
"De-functionalized %s nodes, removed %s nodes", count, count_removed
|
||||
)
|
||||
self.nodes_to_remove.clear()
|
||||
|
||||
def _remove(self, node_or_nodes: torch.fx.Node | Iterable[torch.fx.Node]):
|
||||
"""
|
||||
Stage a node (or nodes) for removal at the end of the pass.
|
||||
"""
|
||||
if isinstance(node_or_nodes, torch.fx.Node):
|
||||
self.nodes_to_remove.append(node_or_nodes)
|
||||
else:
|
||||
self.nodes_to_remove.extend(node_or_nodes)
|
||||
|
||||
def defunctionalize(
|
||||
self,
|
||||
graph: torch.fx.Graph,
|
||||
node: torch.fx.Node,
|
||||
mutated_args: dict[int, torch.fx.Node | str],
|
||||
args: tuple[torch.fx.Node | str, ...] | None = None,
|
||||
):
|
||||
"""
|
||||
De-functionalize a node by replacing it with a call to the original.
|
||||
It also replaces the getitem users with the mutated arguments.
|
||||
See replace_users_with_mutated_args and insert_defunctionalized.
|
||||
"""
|
||||
self.replace_users_with_mutated_args(node, mutated_args)
|
||||
self.insert_defunctionalized(graph, node, args=args)
|
||||
self._remove(node)
|
||||
|
||||
def replace_users_with_mutated_args(
|
||||
self, node: torch.fx.Node, mutated_args: dict[int, torch.fx.Node | str]
|
||||
):
|
||||
"""
|
||||
Replace all getitem users of the auto-functionalized node with the
|
||||
mutated arguments.
|
||||
:param node: The auto-functionalized node
|
||||
:param mutated_args: The mutated arguments, indexed by getitem index.
|
||||
If the value of an arg is a string, `node.kwargs[arg]` is used.
|
||||
"""
|
||||
for idx, user in self.getitem_users(node).items():
|
||||
arg = mutated_args[idx]
|
||||
arg = node.kwargs[arg] if isinstance(arg, str) else arg
|
||||
user.replace_all_uses_with(arg)
|
||||
self._remove(user)
|
||||
|
||||
def getitem_users(self, node: torch.fx.Node) -> dict[int, torch.fx.Node]:
|
||||
"""
|
||||
Returns the operator.getitem users of the auto-functionalized node,
|
||||
indexed by the index they are getting.
|
||||
"""
|
||||
users = {}
|
||||
for user in node.users:
|
||||
if is_func(user, operator.getitem):
|
||||
idx = user.args[1]
|
||||
users[idx] = user
|
||||
return users
|
||||
|
||||
def insert_defunctionalized(
|
||||
self,
|
||||
graph: torch.fx.Graph,
|
||||
node: torch.fx.Node,
|
||||
args: tuple[torch.fx.Node | str, ...] | None = None,
|
||||
):
|
||||
"""
|
||||
Insert a new defunctionalized node into the graph before node.
|
||||
If one of the kwargs is 'out', provide args directly,
|
||||
as node.kwargs cannot be used.
|
||||
See https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351
|
||||
|
||||
:param graph: Graph to insert the defunctionalized node into
|
||||
:param node: The auto-functionalized node to defunctionalize
|
||||
:param args: If we cannot use kwargs, specify args directly.
|
||||
If an arg is a string, `node.kwargs[arg]` is used.
|
||||
""" # noqa: E501
|
||||
assert is_func(node, auto_functionalized), (
|
||||
f"node must be auto-functionalized, is {node} instead"
|
||||
)
|
||||
|
||||
# Create a new call to the original function
|
||||
with graph.inserting_before(node):
|
||||
function = node.args[0]
|
||||
if args is None:
|
||||
graph.call_function(function, kwargs=node.kwargs)
|
||||
else:
|
||||
# Args passed as strings refer to items in node.kwargs
|
||||
args = tuple(
|
||||
node.kwargs[arg] if isinstance(arg, str) else arg for arg in args
|
||||
)
|
||||
graph.call_function(function, args=args)
|
||||
374
compilation/fusion.py
Normal file
374
compilation/fusion.py
Normal file
@@ -0,0 +1,374 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Any, NamedTuple
|
||||
|
||||
import torch
|
||||
import torch._inductor.pattern_matcher as pm
|
||||
from torch import fx
|
||||
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||
from torch._ops import OpOverload
|
||||
|
||||
from vllm.config import VllmConfig, get_current_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
QuantKey,
|
||||
ScaleDesc,
|
||||
kFp8DynamicTensorSym,
|
||||
kFp8DynamicTokenSym,
|
||||
kFp8StaticTensorSym,
|
||||
kNvfp4Quant,
|
||||
kStaticTensorScale,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .inductor_pass import enable_fake_mode
|
||||
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm
|
||||
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
FP4_DTYPE = torch.uint8
|
||||
|
||||
|
||||
def empty_bf16(*args, **kwargs):
|
||||
return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device="cuda")
|
||||
|
||||
|
||||
def empty_fp32(*args, **kwargs):
|
||||
return torch.empty(*args, **kwargs, dtype=torch.float32, device="cuda")
|
||||
|
||||
|
||||
def empty_i32(*args, **kwargs):
|
||||
return torch.empty(*args, **kwargs, dtype=torch.int32, device="cuda")
|
||||
|
||||
|
||||
def empty_i64(*args, **kwargs):
|
||||
return torch.empty(*args, **kwargs, dtype=torch.int64, device="cuda")
|
||||
|
||||
|
||||
RMS_OP = torch.ops._C.rms_norm.default
|
||||
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
|
||||
|
||||
QUANT_OPS: dict[QuantKey, OpOverload] = {
|
||||
kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501
|
||||
kFp8DynamicTensorSym: torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501
|
||||
kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501
|
||||
}
|
||||
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
|
||||
QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default
|
||||
|
||||
|
||||
class FusedRMSQuantKey(NamedTuple):
|
||||
"""
|
||||
Named tuple for identifying the type of RMSNorm + quant fusion.
|
||||
quant: type of quantization
|
||||
fused_add: does the op also perform the residual add
|
||||
"""
|
||||
|
||||
quant: QuantKey
|
||||
fused_add: bool
|
||||
|
||||
def __str__(self):
|
||||
return (
|
||||
f"FusedQuantKey({self.quant}, with"
|
||||
f"{'' if self.fused_add else 'out'} residual)"
|
||||
)
|
||||
|
||||
|
||||
FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = {
|
||||
FusedRMSQuantKey(
|
||||
kFp8StaticTensorSym, False
|
||||
): torch.ops._C.rms_norm_static_fp8_quant.default, # noqa: E501
|
||||
FusedRMSQuantKey(
|
||||
kFp8StaticTensorSym, True
|
||||
): torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, # noqa: E501
|
||||
FusedRMSQuantKey(
|
||||
kFp8DynamicTokenSym, False
|
||||
): torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa: E501
|
||||
FusedRMSQuantKey(
|
||||
kFp8DynamicTokenSym, True
|
||||
): torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa: E501
|
||||
}
|
||||
|
||||
|
||||
class RMSNormQuantPattern:
|
||||
def __init__(self, epsilon: float, key: FusedRMSQuantKey):
|
||||
self.epsilon = epsilon
|
||||
self.quant_dtype = key.quant.dtype
|
||||
config = get_current_vllm_config()
|
||||
self.model_dtype = config.model_config.dtype if config.model_config else None
|
||||
|
||||
assert key in FUSED_OPS, f"unsupported fused rmsnorm+quant op for {key}"
|
||||
self.FUSED_OP = FUSED_OPS[key]
|
||||
|
||||
self.rmsnorm_matcher = (
|
||||
MatcherRMSNorm(epsilon)
|
||||
if not key.fused_add
|
||||
else MatcherFusedAddRMSNorm(epsilon)
|
||||
)
|
||||
self.quant_matcher = MatcherQuantFP8(key.quant)
|
||||
|
||||
|
||||
class RMSNormStaticQuantPattern(RMSNormQuantPattern):
|
||||
def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True):
|
||||
fused_key = FusedRMSQuantKey(
|
||||
fused_add=False,
|
||||
quant=QuantKey(
|
||||
dtype=quant_dtype, scale=kStaticTensorScale, symmetric=symmetric
|
||||
),
|
||||
)
|
||||
super().__init__(epsilon, fused_key)
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
# Cannot use methods, as the self argument affects tracing
|
||||
def pattern(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor):
|
||||
result_rms = self.rmsnorm_matcher(input, weight)
|
||||
return self.quant_matcher(result_rms, scale)[0]
|
||||
|
||||
def replacement(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor):
|
||||
# In case we're matching native rms-norm, conversions might be
|
||||
# optimized out. We convert here just to be safe.
|
||||
input = input.to(dtype=self.model_dtype)
|
||||
|
||||
result = torch.empty(
|
||||
input.shape, device=input.device, dtype=self.quant_dtype
|
||||
)
|
||||
at = auto_functionalized(
|
||||
self.FUSED_OP,
|
||||
result=result,
|
||||
input=input,
|
||||
weight=weight,
|
||||
scale=scale,
|
||||
epsilon=self.epsilon,
|
||||
)
|
||||
|
||||
# result
|
||||
return at[1]
|
||||
|
||||
inputs = [
|
||||
# input, weight
|
||||
*self.rmsnorm_matcher.inputs(),
|
||||
self.quant_matcher.inputs()[1], # scale
|
||||
]
|
||||
pattern(*inputs)
|
||||
|
||||
pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
|
||||
def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True):
|
||||
key = FusedRMSQuantKey(
|
||||
fused_add=True,
|
||||
quant=QuantKey(
|
||||
dtype=quant_dtype, scale=kStaticTensorScale, symmetric=symmetric
|
||||
),
|
||||
)
|
||||
super().__init__(epsilon, key)
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def pattern(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
|
||||
result, _ = self.quant_matcher(result_rms, scale)
|
||||
|
||||
return result, residual
|
||||
|
||||
def replacement(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
# In case we're matching native rms-norm, conversions might be
|
||||
# optimized out. We convert here just to be safe.
|
||||
input = input.to(dtype=self.model_dtype)
|
||||
|
||||
result = torch.empty_like(input, dtype=self.quant_dtype)
|
||||
at = auto_functionalized(
|
||||
self.FUSED_OP,
|
||||
result=result,
|
||||
input=input,
|
||||
residual=residual,
|
||||
weight=weight,
|
||||
scale=scale,
|
||||
epsilon=self.epsilon,
|
||||
)
|
||||
|
||||
# result, residual
|
||||
return at[1], at[2]
|
||||
|
||||
inputs = [
|
||||
# input, weight, residual
|
||||
*self.rmsnorm_matcher.inputs(),
|
||||
self.quant_matcher.inputs()[1], # scale
|
||||
]
|
||||
|
||||
pm.register_replacement(
|
||||
pattern,
|
||||
replacement,
|
||||
inputs,
|
||||
pm.fwd_only,
|
||||
pm_pass,
|
||||
)
|
||||
|
||||
|
||||
class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
|
||||
def __init__(
|
||||
self,
|
||||
epsilon: float,
|
||||
quant_dtype: torch.dtype,
|
||||
group_shape: GroupShape = GroupShape.PER_TOKEN,
|
||||
symmetric=True,
|
||||
):
|
||||
scale = ScaleDesc(torch.float32, False, group_shape)
|
||||
key = FusedRMSQuantKey(
|
||||
fused_add=False,
|
||||
quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
|
||||
)
|
||||
super().__init__(epsilon, key)
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def pattern(input: torch.Tensor, weight: torch.Tensor):
|
||||
result_rms = self.rmsnorm_matcher(input, weight)
|
||||
# result, scale
|
||||
return self.quant_matcher(result_rms)
|
||||
|
||||
def replacement(input: torch.Tensor, weight: torch.Tensor):
|
||||
# In case we're matching native rms-norm, conversions might be
|
||||
# optimized out. We convert here just to be safe.
|
||||
input = input.to(dtype=self.model_dtype)
|
||||
|
||||
result = torch.empty_like(input, dtype=self.quant_dtype)
|
||||
scale = self.quant_matcher.make_scale(input)
|
||||
at = auto_functionalized(
|
||||
self.FUSED_OP,
|
||||
result=result,
|
||||
input=input,
|
||||
weight=weight,
|
||||
scale=scale,
|
||||
epsilon=self.epsilon,
|
||||
scale_ub=None,
|
||||
residual=None,
|
||||
)
|
||||
|
||||
# result, scale
|
||||
return at[1], at[2]
|
||||
|
||||
pm.register_replacement(
|
||||
pattern,
|
||||
replacement,
|
||||
self.rmsnorm_matcher.inputs(),
|
||||
pm.fwd_only,
|
||||
pm_pass,
|
||||
)
|
||||
|
||||
|
||||
class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
|
||||
def __init__(
|
||||
self,
|
||||
epsilon: float,
|
||||
quant_dtype: torch.dtype,
|
||||
group_shape: GroupShape = GroupShape.PER_TOKEN,
|
||||
symmetric=True,
|
||||
):
|
||||
scale = ScaleDesc(torch.float32, False, group_shape)
|
||||
key = FusedRMSQuantKey(
|
||||
fused_add=True,
|
||||
quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
|
||||
)
|
||||
super().__init__(epsilon, key)
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def pattern(input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor):
|
||||
result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
|
||||
result, scale = self.quant_matcher(result_rms)
|
||||
|
||||
return result, residual, scale
|
||||
|
||||
def replacement(
|
||||
input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
|
||||
):
|
||||
# In case we're matching native rms-norm, conversions might be
|
||||
# optimized out. We convert here just to be safe.
|
||||
input = input.to(dtype=self.model_dtype)
|
||||
|
||||
result = torch.empty_like(input, dtype=self.quant_dtype)
|
||||
scale = self.quant_matcher.make_scale(input)
|
||||
at = auto_functionalized(
|
||||
self.FUSED_OP,
|
||||
result=result,
|
||||
input=input,
|
||||
weight=weight,
|
||||
scale=scale,
|
||||
epsilon=self.epsilon,
|
||||
scale_ub=None,
|
||||
residual=residual,
|
||||
)
|
||||
|
||||
# result, residual, scale
|
||||
return at[1], at[3], at[2]
|
||||
|
||||
pm.register_replacement(
|
||||
pattern,
|
||||
replacement,
|
||||
self.rmsnorm_matcher.inputs(),
|
||||
pm.fwd_only,
|
||||
pm_pass,
|
||||
)
|
||||
|
||||
|
||||
class RMSNormQuantFusionPass(VllmPatternMatcherPass):
|
||||
"""
|
||||
This pass fuses rms_norm & quant custom ops into a fused rms_norm_quant op.
|
||||
It also supports fused_add_rms_norm.
|
||||
"""
|
||||
|
||||
@enable_fake_mode
|
||||
def __init__(self, config: VllmConfig):
|
||||
super().__init__(config)
|
||||
|
||||
self.patterns: PatternMatcherPass = PatternMatcherPass(
|
||||
pass_name="rmsnorm_quant_fusion_pass"
|
||||
)
|
||||
|
||||
# Make sure fused add patterns are before simple rms norm,
|
||||
# as the latter is a subset of the former in torch ops
|
||||
for epsilon in [1e-5, 1e-6]:
|
||||
# Fuse fused_add_rms_norm + static fp8 quant
|
||||
FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(
|
||||
self.patterns
|
||||
)
|
||||
|
||||
# Fuse rms_norm + static fp8 quant
|
||||
RMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)
|
||||
|
||||
# Fuse fused_add_rms_norm + dynamic per-token fp8 quant
|
||||
FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
|
||||
self.patterns
|
||||
)
|
||||
|
||||
# Fuse rms_norm + dynamic per-token fp8 quant
|
||||
RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)
|
||||
|
||||
self.dump_patterns(config, self.patterns)
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: fx.Graph):
|
||||
self.matched_count = self.patterns.apply(graph)
|
||||
logger.debug("Replaced %s patterns", self.matched_count)
|
||||
|
||||
def uuid(self) -> Any:
|
||||
return self.hash_source(
|
||||
self,
|
||||
RMSNormQuantPattern,
|
||||
RMSNormStaticQuantPattern,
|
||||
RMSNormDynamicQuantPattern,
|
||||
FusedAddRMSNormStaticQuantPattern,
|
||||
FusedAddRMSNormDynamicQuantPattern,
|
||||
)
|
||||
359
compilation/fusion_attn.py
Normal file
359
compilation/fusion_attn.py
Normal file
@@ -0,0 +1,359 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
import torch._inductor.pattern_matcher as pm
|
||||
from torch import fx
|
||||
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||
|
||||
from vllm.attention import Attention
|
||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
kNvfp4Quant,
|
||||
kStaticTensorScale,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.math_utils import round_up
|
||||
|
||||
from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
|
||||
from .fx_utils import is_func
|
||||
from .inductor_pass import enable_fake_mode
|
||||
from .matcher_utils import MatcherQuantFP8
|
||||
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
FP4_DTYPE = torch.uint8
|
||||
|
||||
ATTN_OP = torch.ops.vllm.unified_attention_with_output.default
|
||||
RESHAPE_OP = torch.ops.aten.reshape.default
|
||||
|
||||
|
||||
class AttentionQuantPattern(ABC):
|
||||
"""
|
||||
The base class for Attn+Quant fusions.
|
||||
Should not be used directly.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer: Attention,
|
||||
quant_key: QuantKey,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
self.layer = layer
|
||||
self.layer_name = layer.layer_name
|
||||
self.num_heads = layer.num_heads
|
||||
self.head_size = layer.head_size
|
||||
self.quant_key = quant_key
|
||||
self.quant_dtype = quant_key.dtype
|
||||
self.dtype = dtype
|
||||
|
||||
assert self.quant_key in QUANT_OPS, (
|
||||
f"unsupported quantization scheme {self.quant_key}"
|
||||
)
|
||||
self.QUANT_OP = QUANT_OPS[self.quant_key]
|
||||
|
||||
def empty(self, *args, **kwargs):
|
||||
kwargs = {"dtype": self.dtype, "device": "cuda", **kwargs}
|
||||
return torch.empty(*args, **kwargs)
|
||||
|
||||
def empty_quant(self, *args, **kwargs):
|
||||
kwargs = {"dtype": self.quant_dtype, "device": "cuda", **kwargs}
|
||||
return torch.empty(*args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def wrap_trace_fn(trace_fn, *process_fx_fns: Callable[[fx.GraphModule], None]):
|
||||
def wrapped(*args, **kwargs):
|
||||
gm = trace_fn(*args, **kwargs)
|
||||
for process_fx in process_fx_fns:
|
||||
process_fx(gm)
|
||||
|
||||
return gm
|
||||
|
||||
return wrapped
|
||||
|
||||
@staticmethod
|
||||
def fx_view_to_reshape(gm: torch.fx.GraphModule):
|
||||
from torch._inductor.fx_passes.post_grad import view_to_reshape
|
||||
|
||||
view_to_reshape(gm)
|
||||
|
||||
@staticmethod
|
||||
def remove_noop_permutes(gm: torch.fx.GraphModule):
|
||||
for node in gm.graph.nodes:
|
||||
if not is_func(node, torch.ops.aten.permute.default):
|
||||
continue
|
||||
|
||||
dims = node.args[1]
|
||||
if any(dim != i for i, dim in enumerate(dims)):
|
||||
continue
|
||||
|
||||
# this is now an identity op, remove
|
||||
node.replace_all_uses_with(node.args[0])
|
||||
gm.graph.erase_node(node)
|
||||
|
||||
def register_if_supported(self, pm_pass: PatternMatcherPass):
|
||||
if self.layer.impl.fused_output_quant_supported(self.quant_key):
|
||||
self._register(pm_pass)
|
||||
|
||||
@abstractmethod
|
||||
def _register(self, pm_pass: PatternMatcherPass):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
|
||||
"""
|
||||
Fusion for Attention+Fp8StaticQuant.
|
||||
|
||||
Only triggers when the attention implementation returns True in
|
||||
`fused_output_quant_supported()`. If the pattern is found, the
|
||||
Fp8StaticQuant op will be removed from the graph, and its scale
|
||||
will be passed into Attention op as the `output_scale` argument.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer: Attention,
|
||||
dtype: torch.dtype,
|
||||
symmetric: bool = True,
|
||||
):
|
||||
quant_key = QuantKey(
|
||||
dtype=FP8_DTYPE, scale=kStaticTensorScale, symmetric=symmetric
|
||||
)
|
||||
super().__init__(layer, quant_key, dtype)
|
||||
self.quant_matcher = MatcherQuantFP8(quant_key)
|
||||
|
||||
def _register(self, pm_pass: PatternMatcherPass):
|
||||
def pattern(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
output_attn: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
at1 = auto_functionalized(
|
||||
ATTN_OP,
|
||||
query=q,
|
||||
key=k,
|
||||
value=v,
|
||||
output=output_attn,
|
||||
layer_name=self.layer_name,
|
||||
output_scale=None,
|
||||
output_block_scale=None,
|
||||
)
|
||||
attn_out_view = RESHAPE_OP(
|
||||
at1[1], [q.shape[0], self.num_heads * self.head_size]
|
||||
)
|
||||
|
||||
return self.quant_matcher(attn_out_view, scale)[0]
|
||||
|
||||
def replacement(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
output_attn: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
# attn output in quant_dtype
|
||||
output_attn = torch.ops.aten.full.default(
|
||||
[q.shape[0], self.num_heads, self.head_size],
|
||||
0.0,
|
||||
dtype=self.quant_dtype,
|
||||
device=q.device,
|
||||
)
|
||||
at1 = auto_functionalized(
|
||||
ATTN_OP,
|
||||
query=q,
|
||||
key=k,
|
||||
value=v,
|
||||
output=output_attn,
|
||||
layer_name=self.layer_name,
|
||||
output_scale=scale,
|
||||
output_block_scale=None,
|
||||
)
|
||||
return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size])
|
||||
|
||||
inputs = [
|
||||
self.empty(5, self.num_heads, self.head_size), # q
|
||||
self.empty(5, self.num_heads, self.head_size), # k
|
||||
self.empty(5, self.num_heads, self.head_size), # v
|
||||
self.empty(5, self.num_heads, self.head_size), # attn_output
|
||||
empty_fp32(1, 1), # scale
|
||||
]
|
||||
|
||||
pm.register_replacement(
|
||||
pattern,
|
||||
replacement,
|
||||
inputs,
|
||||
AttentionQuantPattern.wrap_trace_fn(
|
||||
pm.fwd_only,
|
||||
AttentionQuantPattern.fx_view_to_reshape,
|
||||
AttentionQuantPattern.remove_noop_permutes,
|
||||
),
|
||||
pm_pass,
|
||||
)
|
||||
|
||||
|
||||
class AttentionNvfp4QuantPattern(AttentionQuantPattern):
|
||||
"""
|
||||
Fusion for Attention+Nvfp4Quant.
|
||||
|
||||
Only triggers when the attention implementation returns True in
|
||||
`fused_output_quant_supported()`. If the pattern is found, the
|
||||
Nvfp4Quant op will be removed from the graph, and its scale
|
||||
will be passed into Attention op as the `output_scale` argument.
|
||||
"""
|
||||
|
||||
def __init__(self, layer: Attention, dtype: torch.dtype):
|
||||
super().__init__(layer, kNvfp4Quant, dtype)
|
||||
|
||||
def _register(self, pm_pass: PatternMatcherPass):
|
||||
def pattern(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
output_attn: torch.Tensor,
|
||||
output_quant: torch.Tensor,
|
||||
output_scale: torch.Tensor,
|
||||
input_scale: torch.Tensor,
|
||||
):
|
||||
at1 = auto_functionalized(
|
||||
ATTN_OP,
|
||||
query=q,
|
||||
key=k,
|
||||
value=v,
|
||||
output=output_attn,
|
||||
layer_name=self.layer_name,
|
||||
output_scale=None,
|
||||
output_block_scale=None,
|
||||
)
|
||||
attn_out_view = RESHAPE_OP(
|
||||
at1[1], [q.shape[0], self.num_heads * self.head_size]
|
||||
)
|
||||
at2 = auto_functionalized(
|
||||
self.QUANT_OP,
|
||||
output=output_quant,
|
||||
input=attn_out_view,
|
||||
output_scale=output_scale,
|
||||
input_scale=input_scale,
|
||||
)
|
||||
output_scale_view = torch.ops.aten.view.dtype(at2[2], FP8_DTYPE)
|
||||
return at2[1], output_scale_view
|
||||
|
||||
def replacement(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
output_attn: torch.Tensor,
|
||||
output_quant: torch.Tensor,
|
||||
output_scale: torch.Tensor,
|
||||
input_scale: torch.Tensor,
|
||||
):
|
||||
# attention output in quant_dtype
|
||||
output_attn = torch.ops.aten.full.default(
|
||||
[q.shape[0], self.num_heads, self.head_size // 2],
|
||||
0.0,
|
||||
dtype=self.quant_dtype,
|
||||
device=q.device,
|
||||
)
|
||||
# attention output block scale
|
||||
output_scale_view = torch.ops.aten.view.dtype(output_scale, FP8_DTYPE)
|
||||
at2 = auto_functionalized(
|
||||
ATTN_OP,
|
||||
query=q,
|
||||
key=k,
|
||||
value=v,
|
||||
output=output_attn,
|
||||
layer_name=self.layer_name,
|
||||
output_scale=input_scale,
|
||||
output_block_scale=output_scale_view,
|
||||
)
|
||||
output = RESHAPE_OP(at2[1], [-1, self.num_heads * self.head_size // 2])
|
||||
return output, at2[2]
|
||||
|
||||
inputs = [
|
||||
empty_bf16(5, self.num_heads, self.head_size), # q
|
||||
empty_bf16(5, self.num_heads, self.head_size), # k
|
||||
empty_bf16(5, self.num_heads, self.head_size), # v
|
||||
empty_bf16(5, self.num_heads, self.head_size), # output_attn
|
||||
self.empty_quant(5, self.num_heads * self.head_size // 2), # output_quant
|
||||
empty_i32(
|
||||
128, round_up(self.num_heads * self.head_size // 16, 4)
|
||||
), # output_scale
|
||||
empty_fp32(1, 1), # input_scale
|
||||
]
|
||||
|
||||
pm.register_replacement(
|
||||
pattern,
|
||||
replacement,
|
||||
inputs,
|
||||
AttentionQuantPattern.wrap_trace_fn(
|
||||
pm.fwd_only,
|
||||
AttentionQuantPattern.fx_view_to_reshape,
|
||||
AttentionQuantPattern.remove_noop_permutes,
|
||||
),
|
||||
pm_pass,
|
||||
)
|
||||
|
||||
|
||||
class AttnFusionPass(VllmPatternMatcherPass):
|
||||
"""
|
||||
This pass fuses post-attention quantization onto attention if supported.
|
||||
|
||||
It uses the pattern matcher and matches each layer manually, as strings
|
||||
cannot be wildcarded. This also lets us check support on attention layers
|
||||
upon registration instead of during pattern matching.
|
||||
|
||||
Currently, only static fp8 quant is supported, but patterns could easily be
|
||||
added for other quant schemes and dtypes. The bigger hurdle for wider
|
||||
support are attention kernels, which need to support fusing output quant.
|
||||
"""
|
||||
|
||||
@enable_fake_mode
|
||||
def __init__(self, config: VllmConfig):
|
||||
super().__init__(config)
|
||||
|
||||
self.patterns = PatternMatcherPass(pass_name="attn_fusion_pass")
|
||||
|
||||
attn_layers = get_layers_from_vllm_config(config, Attention)
|
||||
for layer_name, layer in attn_layers.items():
|
||||
pattern_fp8 = AttentionFp8StaticQuantPattern(
|
||||
layer, config.model_config.dtype
|
||||
)
|
||||
pattern_fp8.register_if_supported(self.patterns)
|
||||
|
||||
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
|
||||
pattern_nvfp4 = AttentionNvfp4QuantPattern(
|
||||
layer, config.model_config.dtype
|
||||
)
|
||||
pattern_nvfp4.register_if_supported(self.patterns)
|
||||
|
||||
if len(attn_layers) == 0:
|
||||
logger.warning(
|
||||
"Attention + quant fusion is enabled, but no attention layers "
|
||||
"were found in CompilationConfig.static_forward_context "
|
||||
"so no fusion patterns were registered."
|
||||
)
|
||||
|
||||
self.dump_patterns(config, self.patterns)
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: torch.fx.graph.Graph) -> None:
|
||||
self.matched_count = self.patterns.apply(graph)
|
||||
logger.debug("Fused quant onto %s attention nodes", self.matched_count)
|
||||
|
||||
def uuid(self):
|
||||
return VllmInductorPass.hash_source(
|
||||
self,
|
||||
AttentionQuantPattern,
|
||||
AttentionFp8StaticQuantPattern,
|
||||
AttentionNvfp4QuantPattern,
|
||||
)
|
||||
91
compilation/fx_utils.py
Normal file
91
compilation/fx_utils.py
Normal file
@@ -0,0 +1,91 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import operator
|
||||
from collections.abc import Iterable, Iterator
|
||||
|
||||
from torch import fx
|
||||
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||
from torch._ops import OpOverload, OpOverloadPacket
|
||||
|
||||
|
||||
def is_func(node: fx.Node, target) -> bool:
|
||||
return node.op == "call_function" and node.target == target
|
||||
|
||||
|
||||
def is_auto_func(node: fx.Node, op: OpOverload) -> bool:
|
||||
return is_func(node, auto_functionalized) and node.args[0] == op
|
||||
|
||||
|
||||
# Returns the first specified node with the given op (if it exists)
|
||||
def find_specified_fn_maybe(nodes: Iterable[fx.Node], op: OpOverload) -> fx.Node | None:
|
||||
for node in nodes:
|
||||
if node.target == op:
|
||||
return node
|
||||
return None
|
||||
|
||||
|
||||
# Returns the first specified node with the given op
|
||||
def find_specified_fn(nodes: Iterable[fx.Node], op: OpOverload) -> fx.Node:
|
||||
node = find_specified_fn_maybe(nodes, op)
|
||||
assert node is not None, f"Could not find {op} in nodes {nodes}"
|
||||
return node
|
||||
|
||||
|
||||
# Returns the first auto_functionalized node with the given op (if it exists)
|
||||
def find_auto_fn_maybe(nodes: Iterable[fx.Node], op: OpOverload) -> fx.Node | None:
|
||||
for node in nodes:
|
||||
if is_func(node, auto_functionalized) and node.args[0] == op: # noqa
|
||||
return node
|
||||
return None
|
||||
|
||||
|
||||
# Returns the first auto_functionalized node with the given op
|
||||
def find_auto_fn(nodes: Iterable[fx.Node], op: OpOverload) -> fx.Node:
|
||||
node = find_auto_fn_maybe(nodes, op)
|
||||
assert node is not None, f"Could not find {op} in nodes {nodes}"
|
||||
return node
|
||||
|
||||
|
||||
# Returns the getitem node that extracts the idx-th element from node
|
||||
# (if it exists)
|
||||
def find_getitem_maybe(node: fx.Node, idx: int) -> fx.Node | None:
|
||||
for user in node.users:
|
||||
if is_func(user, operator.getitem) and user.args[1] == idx:
|
||||
return user
|
||||
return None
|
||||
|
||||
|
||||
# Returns the getitem node that extracts the idx-th element from node
|
||||
def find_getitem(node: fx.Node, idx: int) -> fx.Node:
|
||||
ret = find_getitem_maybe(node, idx)
|
||||
assert ret is not None, f"Could not find getitem {idx} in node {node}"
|
||||
return ret
|
||||
|
||||
|
||||
# An auto-functionalization-aware utility for finding nodes with a specific op
|
||||
# Also handles op overload packets and finds all overloads
|
||||
def find_op_nodes(
|
||||
op: OpOverload | OpOverloadPacket, graph: fx.Graph
|
||||
) -> Iterator[fx.Node]:
|
||||
if isinstance(op, OpOverloadPacket):
|
||||
for overload in op.overloads():
|
||||
overload_op = getattr(op, overload)
|
||||
yield from find_op_nodes(overload_op, graph)
|
||||
return
|
||||
|
||||
assert isinstance(op, OpOverload)
|
||||
if not op._schema.is_mutable:
|
||||
yield from graph.find_nodes(op="call_function", target=op)
|
||||
|
||||
for n in graph.find_nodes(op="call_function", target=auto_functionalized):
|
||||
if n.args[0] == op:
|
||||
yield n
|
||||
|
||||
|
||||
# Asserts that the node only has one user and returns it
|
||||
# Even if a node has only 1 user, it might share storage with another node,
|
||||
# which might need to be taken into account.
|
||||
def get_only_user(node: fx.Node) -> fx.Node:
|
||||
assert len(node.users) == 1
|
||||
return next(iter(node.users))
|
||||
133
compilation/inductor_pass.py
Normal file
133
compilation/inductor_pass.py
Normal file
@@ -0,0 +1,133 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import functools
|
||||
import hashlib
|
||||
import inspect
|
||||
import json
|
||||
import types
|
||||
from collections.abc import Callable
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch import fx
|
||||
from torch._subclasses.fake_tensor import FakeTensorMode, unset_fake_temporarily
|
||||
|
||||
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||
|
||||
if is_torch_equal_or_newer("2.6"):
|
||||
from torch._inductor.custom_graph_pass import CustomGraphPass
|
||||
else:
|
||||
# CustomGraphPass is not present in 2.5 or lower, import our version
|
||||
from .torch25_custom_graph_pass import (
|
||||
Torch25CustomGraphPass as CustomGraphPass,
|
||||
)
|
||||
|
||||
_pass_context = None
|
||||
|
||||
|
||||
class PassContext:
|
||||
def __init__(self, runtime_shape: int | None):
|
||||
self.runtime_shape = runtime_shape
|
||||
|
||||
|
||||
def get_pass_context() -> PassContext:
|
||||
"""Get the current pass context."""
|
||||
assert _pass_context is not None
|
||||
return _pass_context
|
||||
|
||||
|
||||
@contextmanager
|
||||
def pass_context(runtime_shape: int | None):
|
||||
"""A context manager that stores the current pass context,
|
||||
usually it is a list of sizes to specialize.
|
||||
"""
|
||||
global _pass_context
|
||||
prev_context = _pass_context
|
||||
_pass_context = PassContext(runtime_shape)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
_pass_context = prev_context
|
||||
|
||||
|
||||
class InductorPass(CustomGraphPass):
|
||||
"""
|
||||
A custom graph pass that uses a hash of its source as the UUID.
|
||||
This is defined as a convenience and should work in most cases.
|
||||
"""
|
||||
|
||||
def uuid(self) -> Any:
|
||||
"""
|
||||
Provide a unique identifier for the pass, used in Inductor code cache.
|
||||
This should depend on the pass implementation, so that changes to the
|
||||
pass result in recompilation.
|
||||
By default, the object source is hashed.
|
||||
"""
|
||||
return InductorPass.hash_source(self)
|
||||
|
||||
@staticmethod
|
||||
def hash_source(*srcs: str | Any):
|
||||
"""
|
||||
Utility method to hash the sources of functions or objects.
|
||||
:param srcs: strings or objects to add to the hash.
|
||||
Objects and functions have their source inspected.
|
||||
:return:
|
||||
"""
|
||||
hasher = hashlib.sha256()
|
||||
for src in srcs:
|
||||
if isinstance(src, str):
|
||||
src_str = src
|
||||
elif isinstance(src, (types.FunctionType, type)):
|
||||
src_str = inspect.getsource(src)
|
||||
else:
|
||||
# object instance
|
||||
src_str = inspect.getsource(src.__class__)
|
||||
hasher.update(src_str.encode("utf-8"))
|
||||
return hasher.hexdigest()
|
||||
|
||||
@staticmethod
|
||||
def hash_dict(dict_: dict[Any, Any]):
|
||||
"""
|
||||
Utility method to hash a dictionary, can alternatively be used for uuid.
|
||||
:return: A sha256 hash of the json rep of the dictionary.
|
||||
"""
|
||||
encoded = json.dumps(dict_, sort_keys=True).encode("utf-8")
|
||||
return hashlib.sha256(encoded).hexdigest()
|
||||
|
||||
def is_applicable(self, shape: int | None):
|
||||
return True
|
||||
|
||||
|
||||
class CallableInductorPass(InductorPass):
|
||||
"""
|
||||
This class is a wrapper for a callable that automatically provides an
|
||||
implementation of the UUID.
|
||||
"""
|
||||
|
||||
def __init__(self, callable: Callable[[fx.Graph], None], uuid: Any | None = None):
|
||||
self.callable = callable
|
||||
self._uuid = self.hash_source(callable) if uuid is None else uuid
|
||||
|
||||
def __call__(self, graph: torch.fx.Graph):
|
||||
self.callable(graph)
|
||||
|
||||
def uuid(self) -> Any:
|
||||
return self._uuid
|
||||
|
||||
|
||||
def enable_fake_mode(fn: Callable[..., Any]) -> Callable[..., Any]:
|
||||
"""
|
||||
Applies a FakeTensorMode context. This is useful when you don't want to
|
||||
create or run things with real tensors.
|
||||
"""
|
||||
|
||||
@functools.wraps(fn)
|
||||
def fn_new(*args, **kwargs) -> Any:
|
||||
with torch._guards.tracing(None), unset_fake_temporarily(), FakeTensorMode():
|
||||
result = fn(*args, **kwargs)
|
||||
|
||||
return result
|
||||
|
||||
return fn_new
|
||||
317
compilation/matcher_utils.py
Normal file
317
compilation/matcher_utils.py
Normal file
@@ -0,0 +1,317 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import torch
|
||||
from torch._higher_order_ops import auto_functionalized
|
||||
from torch._ops import OpOverload
|
||||
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
_normalize_quant_group_shape,
|
||||
kFp8DynamicTensorSym,
|
||||
kFp8DynamicTokenSym,
|
||||
kFp8StaticTensorSym,
|
||||
kNvfp4Quant,
|
||||
)
|
||||
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
RMS_OP = torch.ops._C.rms_norm.default
|
||||
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
|
||||
ROTARY_OP = torch.ops._C.rotary_embedding.default
|
||||
FLASHINFER_ROTARY_OP = torch.ops.vllm.flashinfer_rotary_embedding.default
|
||||
|
||||
QUANT_OPS: dict[QuantKey, OpOverload] = {
|
||||
kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501
|
||||
kFp8DynamicTensorSym: torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501
|
||||
kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501
|
||||
}
|
||||
|
||||
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
|
||||
QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default # noqa: E501
|
||||
|
||||
SILU_MUL_OP = torch.ops._C.silu_and_mul.default
|
||||
|
||||
|
||||
class MatcherCustomOp(ABC):
|
||||
def __init__(self, enabled: bool):
|
||||
config = get_current_vllm_config()
|
||||
self.model_dtype = config.model_config.dtype if config.model_config else None
|
||||
self.device = config.device_config.device if config.device_config else None
|
||||
|
||||
self.enabled = enabled
|
||||
self.forward = self.forward_custom if enabled else self.forward_native
|
||||
|
||||
@abstractmethod
|
||||
def forward_custom(self, *args, **kws):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def forward_native(self, *args, **kws):
|
||||
pass
|
||||
|
||||
def __call__(self, *args, **kws):
|
||||
return self.forward(*args, **kws)
|
||||
|
||||
def empty(self, *args, **kws):
|
||||
return torch.empty(*args, dtype=self.model_dtype, device=self.device, **kws)
|
||||
|
||||
def empty_int64(self, *args, **kws):
|
||||
return torch.empty(*args, dtype=torch.int64, device=self.device, **kws)
|
||||
|
||||
def empty_f32(self, *args, **kws):
|
||||
return torch.empty(*args, dtype=torch.float32, device=self.device, **kws)
|
||||
|
||||
def inputs(self) -> list[torch.Tensor]:
|
||||
"""Utility for inputs to the pattern"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class MatcherRotaryEmbedding(MatcherCustomOp):
|
||||
def __init__(
|
||||
self,
|
||||
is_neox: bool,
|
||||
head_size: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
use_flashinfer: bool = False,
|
||||
enabled: bool | None = None,
|
||||
) -> None:
|
||||
if enabled is None:
|
||||
enabled = RotaryEmbedding.enabled()
|
||||
|
||||
super().__init__(enabled)
|
||||
self.is_neox = is_neox
|
||||
self.head_size = head_size
|
||||
self.num_heads = num_heads
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.q_size = self.num_heads * self.head_size
|
||||
self.kv_size = self.num_kv_heads * self.head_size
|
||||
self.rotary_dim = head_size
|
||||
if use_flashinfer:
|
||||
self.rotary_op = FLASHINFER_ROTARY_OP
|
||||
else:
|
||||
self.rotary_op = ROTARY_OP
|
||||
|
||||
def inputs(self) -> list[torch.Tensor]:
|
||||
positions = self.empty_int64(5)
|
||||
query = self.empty(5, self.q_size)
|
||||
key = self.empty(5, self.kv_size)
|
||||
cos_sin_cache = self.empty(4096, self.rotary_dim)
|
||||
return [positions, query, key, cos_sin_cache]
|
||||
|
||||
def forward_custom(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor | None,
|
||||
cos_sin_cache: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
result = auto_functionalized(
|
||||
self.rotary_op,
|
||||
positions=positions,
|
||||
query=query,
|
||||
key=key,
|
||||
head_size=self.head_size,
|
||||
cos_sin_cache=cos_sin_cache,
|
||||
is_neox=self.is_neox,
|
||||
)
|
||||
query_out = result[1]
|
||||
key_out = result[2] if len(result) > 2 else None
|
||||
return query_out, key_out
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor | None,
|
||||
cos_sin_cache: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
return RotaryEmbedding.forward_static(
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
self.head_size,
|
||||
self.rotary_dim,
|
||||
cos_sin_cache,
|
||||
self.is_neox,
|
||||
)
|
||||
|
||||
|
||||
class MatcherRMSNorm(MatcherCustomOp):
|
||||
def __init__(self, epsilon: float, enabled: bool | None = None):
|
||||
if enabled is None:
|
||||
enabled = RMSNorm.enabled()
|
||||
|
||||
super().__init__(enabled)
|
||||
self.epsilon = epsilon
|
||||
|
||||
def inputs(self):
|
||||
input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16)
|
||||
weight = self.empty(16)
|
||||
return [input, weight]
|
||||
|
||||
def forward_custom(
|
||||
self,
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
result = torch.empty_like(input)
|
||||
# TODO: support non-contiguous input for RMSNorm and remove this
|
||||
input_contiguous = input.contiguous()
|
||||
_, result = auto_functionalized(
|
||||
RMS_OP,
|
||||
result=result,
|
||||
input=input_contiguous,
|
||||
weight=weight,
|
||||
epsilon=self.epsilon,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
return RMSNorm.forward_static(
|
||||
input, self.epsilon, input.size(-1), self.model_dtype, weight
|
||||
)
|
||||
|
||||
|
||||
class MatcherFusedAddRMSNorm(MatcherCustomOp):
|
||||
def __init__(self, epsilon: float, enabled: bool | None = None):
|
||||
if enabled is None:
|
||||
enabled = RMSNorm.enabled()
|
||||
|
||||
super().__init__(enabled)
|
||||
self.epsilon = epsilon
|
||||
|
||||
def inputs(self):
|
||||
input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16)
|
||||
weight = self.empty(16)
|
||||
residual = self.empty(5, 16)
|
||||
return [input, weight, residual]
|
||||
|
||||
def forward_custom(
|
||||
self,
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
_, result, residual = auto_functionalized(
|
||||
RMS_ADD_OP,
|
||||
input=input,
|
||||
residual=residual,
|
||||
weight=weight,
|
||||
epsilon=self.epsilon,
|
||||
)
|
||||
|
||||
return result, residual
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
return RMSNorm.forward_static(
|
||||
input, self.epsilon, input.size(-1), self.model_dtype, weight, residual
|
||||
)
|
||||
|
||||
|
||||
class MatcherQuantFP8(MatcherCustomOp):
|
||||
def __init__(self, quant_key: QuantKey, enabled: bool | None = None):
|
||||
if enabled is None:
|
||||
enabled = QuantFP8.enabled()
|
||||
|
||||
super().__init__(enabled)
|
||||
self.quant_key = quant_key
|
||||
assert quant_key in QUANT_OPS, f"unsupported quantization scheme {quant_key}"
|
||||
self.QUANT_OP = QUANT_OPS[quant_key]
|
||||
|
||||
assert quant_key.dtype == current_platform.fp8_dtype(), (
|
||||
"Only QuantFP8 supported by"
|
||||
)
|
||||
assert quant_key.scale2 is None
|
||||
self.quant_fp8 = QuantFP8(quant_key.scale.static, quant_key.scale.group_shape)
|
||||
|
||||
def forward_custom(
|
||||
self,
|
||||
input: torch.Tensor,
|
||||
scale: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
result = torch.empty(
|
||||
input.shape, device=input.device, dtype=self.quant_key.dtype
|
||||
)
|
||||
|
||||
if self.quant_key.scale.static:
|
||||
assert scale is not None
|
||||
_, result = auto_functionalized(
|
||||
self.QUANT_OP, result=result, input=input, scale=scale
|
||||
)
|
||||
return result, scale
|
||||
else:
|
||||
assert scale is None
|
||||
scale = self.make_scale(input)
|
||||
_, result, scale = auto_functionalized(
|
||||
self.QUANT_OP, result=result, input=input, scale=scale, scale_ub=None
|
||||
)
|
||||
return result, scale
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
input: torch.Tensor,
|
||||
scale: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
return self.quant_fp8(input, scale)
|
||||
|
||||
def make_scale(self, input: torch.Tensor):
|
||||
normalized_group_shape = _normalize_quant_group_shape(
|
||||
input, self.quant_key.scale.group_shape
|
||||
)
|
||||
scale_shape = (
|
||||
input.shape[0] // normalized_group_shape[0],
|
||||
input.shape[1] // normalized_group_shape[1],
|
||||
)
|
||||
|
||||
return torch.empty(scale_shape, device=input.device, dtype=torch.float32)
|
||||
|
||||
def inputs(self) -> list[torch.Tensor]:
|
||||
input = self.empty(5, 16)
|
||||
if self.quant_key.scale.static:
|
||||
return [input, self.empty_f32(1, 1)]
|
||||
|
||||
return [input]
|
||||
|
||||
|
||||
class MatcherSiluAndMul(MatcherCustomOp):
|
||||
def __init__(self, enabled: bool | None = None):
|
||||
if enabled is None:
|
||||
enabled = SiluAndMul.enabled()
|
||||
super().__init__(enabled)
|
||||
|
||||
def inputs(self) -> list[torch.Tensor]:
|
||||
input = self.empty(5, 4)
|
||||
return [input]
|
||||
|
||||
def forward_custom(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
d = x.shape[-1] // 2
|
||||
output_shape = x.shape[:-1] + (d,)
|
||||
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
|
||||
result = auto_functionalized(SILU_MUL_OP, result=out, input=x)
|
||||
return result[1]
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
return SiluAndMul.forward_native(x)
|
||||
62
compilation/monitor.py
Normal file
62
compilation/monitor.py
Normal file
@@ -0,0 +1,62 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import time
|
||||
|
||||
from vllm.config import CompilationConfig, CompilationMode, VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
context_manager = None
|
||||
torch_compile_start_time: float = 0.0
|
||||
|
||||
|
||||
def start_monitoring_torch_compile(vllm_config: VllmConfig):
|
||||
global torch_compile_start_time
|
||||
torch_compile_start_time = time.time()
|
||||
|
||||
compilation_config: CompilationConfig = vllm_config.compilation_config
|
||||
path = vllm_config.compile_debug_dump_path()
|
||||
if compilation_config.mode == CompilationMode.VLLM_COMPILE and path:
|
||||
import depyf
|
||||
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
logger.debug("Dumping depyf output to %s", path)
|
||||
global context_manager
|
||||
context_manager = depyf.prepare_debug(path.as_posix())
|
||||
context_manager.__enter__()
|
||||
|
||||
|
||||
def end_monitoring_torch_compile(vllm_config: VllmConfig):
|
||||
compilation_config: CompilationConfig = vllm_config.compilation_config
|
||||
if compilation_config.mode == CompilationMode.VLLM_COMPILE:
|
||||
logger.info_once(
|
||||
"torch.compile takes %.2f s in total",
|
||||
compilation_config.compilation_time,
|
||||
scope="local",
|
||||
)
|
||||
global context_manager
|
||||
if context_manager is not None:
|
||||
context_manager.__exit__(None, None, None)
|
||||
context_manager = None
|
||||
|
||||
|
||||
cudagraph_capturing_enabled: bool = True
|
||||
|
||||
|
||||
def validate_cudagraph_capturing_enabled():
|
||||
# used to monitor whether a cudagraph capturing is legal at runtime.
|
||||
# should be called before any cudagraph capturing.
|
||||
# if an illegal cudagraph capturing happens, raise an error.
|
||||
global cudagraph_capturing_enabled
|
||||
if not cudagraph_capturing_enabled:
|
||||
raise RuntimeError(
|
||||
"CUDA graph capturing detected at an inappropriate "
|
||||
"time. This operation is currently disabled."
|
||||
)
|
||||
|
||||
|
||||
def set_cudagraph_capturing_enabled(enabled: bool):
|
||||
global cudagraph_capturing_enabled
|
||||
cudagraph_capturing_enabled = enabled
|
||||
134
compilation/noop_elimination.py
Normal file
134
compilation/noop_elimination.py
Normal file
@@ -0,0 +1,134 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Iterable
|
||||
|
||||
import torch.fx
|
||||
from torch import SymInt
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .fx_utils import is_func
|
||||
from .vllm_inductor_pass import VllmInductorPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class NoOpEliminationPass(VllmInductorPass):
|
||||
"""
|
||||
This is an inductor pass that removes redundant reshape/slice operations.
|
||||
It is required for RMSNorm-quant fusion to work properly.
|
||||
That's because apply_fp8_linear adds a reshape, which is redundant
|
||||
in the 2D-case. Additionally, torch internal no-op elimination pass does
|
||||
not handle certain slice variants.
|
||||
|
||||
Cases handled:
|
||||
1. A chain of reshapes is equivalent to the last reshape called on the
|
||||
base tensor (input of the first reshape).
|
||||
2. A reshape that produces the shape of the input is redundant
|
||||
3. A slice that produces the shape of the input is redundant
|
||||
|
||||
Example graph 1:
|
||||
mul_1: "f16[s0, 4096]" = ...
|
||||
view_1: "f16[s0, 128, 32]" = torch.reshape(mul_1, [-1, 128, 32])
|
||||
view_2: "f16[s0, 4096]" = torch.reshape(view_2, [-1, 4096])
|
||||
view_3: "f16[s0, 128, 32]" = torch.reshape(view_3, [-1, 128, 32])
|
||||
|
||||
Can be replaced with:
|
||||
mul_1: "f16[s0, 4096]" = ...
|
||||
view_3: "f16[s0, 128, 32]" = ...
|
||||
|
||||
Example graph 2:
|
||||
getitem_1: "f16[s0, 4096]" = ...
|
||||
view_1: "f16[s0, 4096]" = torch.reshape(getitem_1, [-1, 4096])
|
||||
at = auto_functionalized(static_scaled_fp8_quant, input = view_1, ...)
|
||||
out: "f8e4m3fn[s0, 4096]" = at[1]
|
||||
|
||||
Can be replaced with:
|
||||
getitem_1: "f16[s0, 4096]" = ...
|
||||
at = auto_functionalized(static_scaled_fp8_quant, input = getitem_1, ...)
|
||||
out: "f8e4m3fn[s0, 4096]" = at[1]
|
||||
|
||||
Example graph 3:
|
||||
arg0: "s0" = SymInt(s0)
|
||||
scaled_mm: "f16[s0, 4096]" = ...
|
||||
slice_1: "f16[s0, 4096]" = torch.slice(scaled_mm, -1, 0, arg0)
|
||||
at = auto_functionalized(fused_add_rms_norm, input = slice_1, ...)
|
||||
out: "f16[s0, 4096]" = torch.slice_scatter(scaled_mm, at[1], 0, 0, arg0)
|
||||
|
||||
Can be replaced with:
|
||||
arg0: "s0" = SymInt(s0)
|
||||
scaled_mm: "f16[s0, 4096]" = ...
|
||||
at = auto_functionalized(fused_add_rms_norm, input = scaled_mm, ...)
|
||||
out: "f16[s0, 4096]" = at[1]
|
||||
"""
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: torch.fx.Graph):
|
||||
count = 0
|
||||
# Remove no-op reshapes/views:
|
||||
for node in graph.nodes:
|
||||
if is_func(node, torch.ops.aten.reshape.default):
|
||||
# Case 1: rewrite reshape chains to reshapes on the base tensor
|
||||
input = node.args[0]
|
||||
# If the input is a reshape, rebind to that node
|
||||
if is_func(input, torch.ops.aten.reshape.default):
|
||||
# The new input is guaranteed not to be a reshape,
|
||||
# because we process nodes in order
|
||||
node.update_arg(0, input.args[0])
|
||||
if len(input.users) == 0:
|
||||
graph.erase_node(input)
|
||||
count += 1
|
||||
|
||||
# remove reshape/slice if it produces the original shape
|
||||
if is_func(node, torch.ops.aten.reshape.default) or is_func(
|
||||
node, torch.ops.aten.slice.Tensor
|
||||
):
|
||||
input = node.args[0]
|
||||
input_shape = input.meta["val"].shape
|
||||
output_shape = node.meta["val"].shape
|
||||
if self.all_dims_equivalent(input_shape, output_shape):
|
||||
node.replace_all_uses_with(input)
|
||||
graph.erase_node(node)
|
||||
count += 1
|
||||
elif is_func(node, torch.ops.aten.slice_scatter.default):
|
||||
base, view, dim_index, start, end = node.args[:5]
|
||||
base_shape = base.meta["val"].shape
|
||||
view_shape = view.meta["val"].shape
|
||||
|
||||
if self.all_dims_equivalent(base_shape, view_shape):
|
||||
node.replace_all_uses_with(view)
|
||||
graph.erase_node(node)
|
||||
count += 1
|
||||
|
||||
logger.debug("Removed %s no-op reshapes and slices", count)
|
||||
|
||||
# ---------------------- Shape comparison helpers ----------------------
|
||||
def dims_equivalent(self, dim: int | SymInt, i_dim: int | SymInt) -> bool:
|
||||
"""
|
||||
This function checks if two dimensions are equivalent.
|
||||
:param dim: The dimension arg to reshape/slice
|
||||
:param i_dim: The corresponding dimension in the input tensor
|
||||
:return: Are the dimensions equivalent?
|
||||
|
||||
There are two cases in which the dimensions are equivalent:
|
||||
1. The dimensions are equal (both integers)
|
||||
2. The dimensions both correspond to the same SymInt
|
||||
"""
|
||||
# Case 1
|
||||
if isinstance(i_dim, int) and isinstance(dim, int):
|
||||
return dim == i_dim
|
||||
# Case 2
|
||||
if isinstance(i_dim, SymInt) and isinstance(dim, SymInt):
|
||||
return dim == i_dim
|
||||
return False
|
||||
|
||||
def all_dims_equivalent(
|
||||
self, dims: Iterable[int | SymInt], i_dims: Iterable[int | SymInt]
|
||||
) -> bool:
|
||||
dims_ = list(dims)
|
||||
i_dims_ = list(i_dims)
|
||||
if len(dims_) != len(i_dims_):
|
||||
# Different ranks can't be equivalent
|
||||
return False
|
||||
return all(self.dims_equivalent(s, i_s) for s, i_s in zip(dims, i_dims))
|
||||
72
compilation/partition_rules.py
Normal file
72
compilation/partition_rules.py
Normal file
@@ -0,0 +1,72 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import contextlib
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def should_split(node: torch.fx.Node, splitting_ops: list[str]) -> bool:
|
||||
"""
|
||||
Check if a node should be split for dynamo graph partition.
|
||||
It operates on dynamo graph, so the node.target can be anything.
|
||||
We need to check and split only on OpOverload and OpOverloadPacket.
|
||||
"""
|
||||
|
||||
if node.op != "call_function":
|
||||
return False
|
||||
|
||||
target = node.target
|
||||
|
||||
if isinstance(target, torch._ops.OpOverloadPacket):
|
||||
# Example: "aten::add"
|
||||
return target._qualified_op_name in splitting_ops
|
||||
|
||||
if isinstance(target, torch._ops.OpOverload):
|
||||
# Example: "aten::add"
|
||||
packet_name = target.name()
|
||||
|
||||
# Example: "aten::add.default"
|
||||
op_overload_name = f"{packet_name}.{target._overloadname}"
|
||||
return op_overload_name in splitting_ops or packet_name in splitting_ops
|
||||
|
||||
return False
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def inductor_partition_rule_context(splitting_ops: list[str]):
|
||||
"""Context manager to temporarily register Inductor partition rules.
|
||||
|
||||
Registers custom partition rules for specified operators, forcing the
|
||||
Inductor scheduler to partition the graph at these operators. The rules
|
||||
are automatically restored to their previous state on exit.
|
||||
|
||||
Args:
|
||||
splitting_ops: List of operator names to partition on.
|
||||
"""
|
||||
if not splitting_ops:
|
||||
logger.debug("No partition ops provided; skipping rule registration.")
|
||||
yield
|
||||
return
|
||||
|
||||
# Save current state before registering
|
||||
|
||||
saved_splitting_ops: list[str] = list(
|
||||
torch._inductor.config.custom_should_partition_ops
|
||||
)
|
||||
torch._inductor.config.custom_should_partition_ops = splitting_ops
|
||||
|
||||
logger.debug(
|
||||
"Registered inductor partition rules for %d operators", len(splitting_ops)
|
||||
)
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
# Clear and restore previous state
|
||||
torch._inductor.config.custom_should_partition_ops = saved_splitting_ops
|
||||
logger.debug("Restored previous partition rules state.")
|
||||
135
compilation/pass_manager.py
Normal file
135
compilation/pass_manager.py
Normal file
@@ -0,0 +1,135 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import functools
|
||||
|
||||
from torch import fx as fx
|
||||
|
||||
from vllm import envs
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.system_utils import set_env_var
|
||||
|
||||
from .post_cleanup import PostCleanupPass
|
||||
from .vllm_inductor_pass import VllmInductorPass
|
||||
|
||||
if current_platform.is_cuda_alike():
|
||||
from .activation_quant_fusion import ActivationQuantFusionPass
|
||||
from .fusion import RMSNormQuantFusionPass
|
||||
from .fusion_attn import AttnFusionPass
|
||||
from .qk_norm_rope_fusion import QKNormRoPEFusionPass
|
||||
from .sequence_parallelism import SequenceParallelismPass
|
||||
|
||||
if current_platform.is_cuda():
|
||||
from .collective_fusion import AllReduceFusionPass, AsyncTPPass
|
||||
|
||||
from .fix_functionalization import FixFunctionalizationPass
|
||||
from .inductor_pass import CustomGraphPass, InductorPass, get_pass_context
|
||||
from .noop_elimination import NoOpEliminationPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def with_pattern_match_debug(fn):
|
||||
"""
|
||||
Function decorator that turns on inductor pattern match debug
|
||||
for the duration of the call.
|
||||
Used to avoid logging builtin Inductor pattern matching.
|
||||
"""
|
||||
|
||||
@functools.wraps(fn)
|
||||
def wrapper(*args, **kwargs):
|
||||
if (debug_val := envs.VLLM_PATTERN_MATCH_DEBUG) is not None:
|
||||
# optionally check rank here
|
||||
with set_env_var("TORCHINDUCTOR_PATTERN_MATCH_DEBUG", debug_val):
|
||||
return fn(*args, **kwargs)
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class PostGradPassManager(CustomGraphPass):
|
||||
"""
|
||||
The pass manager for post-grad passes.
|
||||
It handles configuration, adding custom passes, and running passes.
|
||||
It supports uuid for the Inductor code cache. That includes torch<2.6
|
||||
support using pickling (in .inductor_pass.CustomGraphPass).
|
||||
|
||||
The order of the post-grad post-passes is:
|
||||
1. passes (constructor parameter)
|
||||
2. default passes (NoopEliminationPass, FusionPass)
|
||||
3. config["post_grad_custom_post_pass"] (if it exists)
|
||||
4. fix_functionalization
|
||||
This way, all passes operate on a functionalized graph.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.passes: list[InductorPass] = []
|
||||
|
||||
@with_pattern_match_debug
|
||||
def __call__(self, graph: fx.Graph):
|
||||
VllmInductorPass.dump_prefix = 0 # reset dump index
|
||||
|
||||
shape = get_pass_context().runtime_shape
|
||||
for pass_ in self.passes:
|
||||
if pass_.is_applicable(shape):
|
||||
pass_(graph)
|
||||
VllmInductorPass.dump_prefix += 1
|
||||
else:
|
||||
logger.debug("Skipping %s with shape %s", pass_, shape)
|
||||
|
||||
# post-cleanup goes before fix_functionalization
|
||||
# because it requires a functional graph
|
||||
self.post_cleanup(graph)
|
||||
VllmInductorPass.dump_prefix += 1
|
||||
|
||||
# always run fix_functionalization last
|
||||
self.fix_functionalization(graph)
|
||||
VllmInductorPass.dump_prefix = None # Cleanup index
|
||||
|
||||
def configure(self, config: VllmConfig):
|
||||
self.pass_config = config.compilation_config.pass_config
|
||||
|
||||
# Set the current vllm config to allow tracing CustomOp instances
|
||||
with set_current_vllm_config(config, check_compile=False):
|
||||
if self.pass_config.enable_noop:
|
||||
self.passes += [NoOpEliminationPass(config)]
|
||||
|
||||
if self.pass_config.enable_sequence_parallelism:
|
||||
self.passes += [SequenceParallelismPass(config)]
|
||||
if self.pass_config.enable_async_tp:
|
||||
self.passes += [AsyncTPPass(config)]
|
||||
|
||||
if self.pass_config.enable_fi_allreduce_fusion:
|
||||
self.passes += [AllReduceFusionPass(config)]
|
||||
|
||||
if self.pass_config.enable_fusion:
|
||||
self.passes += [RMSNormQuantFusionPass(config)]
|
||||
self.passes += [ActivationQuantFusionPass(config)]
|
||||
|
||||
if self.pass_config.enable_attn_fusion:
|
||||
self.passes += [AttnFusionPass(config)]
|
||||
|
||||
if self.pass_config.enable_qk_norm_rope_fusion:
|
||||
self.passes += [QKNormRoPEFusionPass(config)]
|
||||
|
||||
# needs a functional graph
|
||||
self.post_cleanup = PostCleanupPass(config)
|
||||
self.fix_functionalization = FixFunctionalizationPass(config)
|
||||
|
||||
def add(self, pass_: InductorPass):
|
||||
assert isinstance(pass_, InductorPass)
|
||||
self.passes.append(pass_)
|
||||
|
||||
def uuid(self):
|
||||
"""
|
||||
The PostGradPassManager is set as a custom pass in the Inductor and
|
||||
affects compilation caching. Its uuid depends on the UUIDs of all
|
||||
dependent passes and the pass config. See InductorPass for more info.
|
||||
"""
|
||||
state = {"pass_config": self.pass_config.uuid(), "passes": []}
|
||||
for pass_ in self.passes:
|
||||
state["passes"].append(pass_.uuid())
|
||||
state["passes"].append(self.fix_functionalization.uuid())
|
||||
|
||||
return InductorPass.hash_dict(state)
|
||||
121
compilation/piecewise_backend.py
Normal file
121
compilation/piecewise_backend.py
Normal file
@@ -0,0 +1,121 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import dataclasses
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
import torch.fx as fx
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.backends import VllmBackend
|
||||
from vllm.compilation.monitor import end_monitoring_torch_compile
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ConcreteSizeEntry:
|
||||
runtime_shape: int
|
||||
compiled: bool = False
|
||||
runnable: Callable = None # type: ignore
|
||||
|
||||
|
||||
class PiecewiseBackend:
|
||||
def __init__(
|
||||
self,
|
||||
graph: fx.GraphModule,
|
||||
vllm_config: VllmConfig,
|
||||
piecewise_compile_index: int,
|
||||
total_piecewise_compiles: int,
|
||||
sym_shape_indices: list[int],
|
||||
compiled_graph_for_general_shape: Callable,
|
||||
vllm_backend: VllmBackend,
|
||||
):
|
||||
"""
|
||||
The backend for piecewise compilation.
|
||||
It mainly handles the compilation of static shapes and
|
||||
dispatching based on runtime shape.
|
||||
|
||||
We will compile `self.graph` once for the general shape,
|
||||
and then compile for different shapes specified in
|
||||
`compilation_config.compile_sizes`.
|
||||
"""
|
||||
self.graph = graph
|
||||
self.vllm_config = vllm_config
|
||||
self.compilation_config = vllm_config.compilation_config
|
||||
self.piecewise_compile_index = piecewise_compile_index
|
||||
self.total_piecewise_compiles = total_piecewise_compiles
|
||||
self.vllm_backend = vllm_backend
|
||||
|
||||
self.is_first_graph = piecewise_compile_index == 0
|
||||
self.is_last_graph = piecewise_compile_index == total_piecewise_compiles - 1
|
||||
|
||||
self.is_full_graph = total_piecewise_compiles == 1
|
||||
|
||||
self.compile_sizes: set[int] = set(self.compilation_config.compile_sizes)
|
||||
|
||||
self.first_run_finished = False
|
||||
|
||||
self.compiled_graph_for_general_shape = compiled_graph_for_general_shape # noqa
|
||||
|
||||
self.sym_shape_indices = sym_shape_indices
|
||||
|
||||
self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"
|
||||
|
||||
# the entries for different shapes that we need to compile
|
||||
self.concrete_size_entries: dict[int, ConcreteSizeEntry] = {}
|
||||
|
||||
# to_be_compiled_sizes tracks the remaining sizes to compile,
|
||||
# and updates during the compilation process, so we need to copy it
|
||||
self.to_be_compiled_sizes: set[int] = self.compile_sizes.copy()
|
||||
|
||||
# We only keep compilation management inside this class directly.
|
||||
for shape in self.compile_sizes:
|
||||
self.concrete_size_entries[shape] = ConcreteSizeEntry(
|
||||
runtime_shape=shape,
|
||||
runnable=self.compiled_graph_for_general_shape,
|
||||
)
|
||||
|
||||
def check_for_ending_compilation(self):
|
||||
if self.is_last_graph and not self.to_be_compiled_sizes:
|
||||
# no specific sizes to compile
|
||||
# save the hash of the inductor graph for the next run
|
||||
self.vllm_backend.compiler_manager.save_to_file()
|
||||
end_monitoring_torch_compile(self.vllm_config)
|
||||
|
||||
def __call__(self, *args) -> Any:
|
||||
if not self.first_run_finished:
|
||||
self.first_run_finished = True
|
||||
self.check_for_ending_compilation()
|
||||
return self.compiled_graph_for_general_shape(*args)
|
||||
|
||||
runtime_shape = args[self.sym_shape_indices[0]]
|
||||
|
||||
if runtime_shape not in self.concrete_size_entries:
|
||||
# we don't need to do anything for this shape
|
||||
return self.compiled_graph_for_general_shape(*args)
|
||||
|
||||
entry = self.concrete_size_entries[runtime_shape]
|
||||
|
||||
if not entry.compiled:
|
||||
entry.compiled = True
|
||||
self.to_be_compiled_sizes.remove(runtime_shape)
|
||||
# args are real arguments
|
||||
entry.runnable = self.vllm_backend.compiler_manager.compile(
|
||||
self.graph,
|
||||
args,
|
||||
self.compilation_config.inductor_compile_config,
|
||||
self.compilation_config,
|
||||
graph_index=self.piecewise_compile_index,
|
||||
num_graphs=self.total_piecewise_compiles,
|
||||
runtime_shape=runtime_shape,
|
||||
)
|
||||
|
||||
# finished compilations for all required shapes
|
||||
if self.is_last_graph and not self.to_be_compiled_sizes:
|
||||
self.check_for_ending_compilation()
|
||||
|
||||
return entry.runnable(*args)
|
||||
21
compilation/post_cleanup.py
Normal file
21
compilation/post_cleanup.py
Normal file
@@ -0,0 +1,21 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from torch import fx
|
||||
|
||||
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
|
||||
|
||||
|
||||
class PostCleanupPass(VllmInductorPass):
|
||||
"""
|
||||
This pass performs cleanup after custom passes.
|
||||
It topologically sorts the graph and removes unused nodes.
|
||||
This is needed because the pattern matcher does not guarantee producing
|
||||
a topologically sorted graph, and there may be unused nodes left around.
|
||||
"""
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: fx.Graph) -> None:
|
||||
from torch._inductor.pattern_matcher import stable_topological_sort
|
||||
|
||||
stable_topological_sort(graph)
|
||||
graph.eliminate_dead_code()
|
||||
238
compilation/qk_norm_rope_fusion.py
Normal file
238
compilation/qk_norm_rope_fusion.py
Normal file
@@ -0,0 +1,238 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
import torch._inductor.pattern_matcher as pm
|
||||
from torch import fx
|
||||
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||
|
||||
from vllm.attention import Attention
|
||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
|
||||
|
||||
from .fusion import empty_bf16, empty_fp32, empty_i64
|
||||
from .inductor_pass import enable_fake_mode
|
||||
from .matcher_utils import MatcherRMSNorm, MatcherRotaryEmbedding
|
||||
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
FUSED_QK_ROPE_OP = torch.ops._C.fused_qk_norm_rope.default
|
||||
|
||||
|
||||
class QkNormRopePattern:
|
||||
"""
|
||||
Match the unfused sequence in attention blocks and replace with the fused op.
|
||||
|
||||
Unfused (conceptually):
|
||||
q, k, v = split(qkv, [qsz, kvsz, kvsz], -1)
|
||||
qh = reshape(q, [-1, num_heads, head_dim])
|
||||
kh = reshape(k, [-1, num_kv_heads, head_dim])
|
||||
qn = rms_norm(qh, q_weight, eps)
|
||||
kn = rms_norm(kh, k_weight, eps)
|
||||
qf = reshape(qn, [-1, num_heads * head_dim])
|
||||
kf = reshape(kn, [-1, num_kv_heads * head_dim])
|
||||
qf, kf = rotary_embedding(positions, qf, kf, head_dim, cos_sin_cache, is_neox)
|
||||
return qf, kf, v
|
||||
|
||||
Fused replacement:
|
||||
fused_qk_norm_rope(qkv, num_heads, num_kv_heads, num_kv_heads, head_dim,
|
||||
eps, q_weight, k_weight, cos_sin_cache, is_neox,
|
||||
positions.view(-1))
|
||||
return split(qkv, [qsz, kvsz, kvsz], -1)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
head_dim: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
eps: float,
|
||||
is_neox: bool,
|
||||
rope_flashinfer: bool = False,
|
||||
) -> None:
|
||||
self.num_heads = num_heads
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.head_dim = head_dim
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
self.eps = eps
|
||||
self.rmsnorm_matcher = MatcherRMSNorm(eps)
|
||||
self.is_neox = is_neox
|
||||
self.rope_flashinfer = rope_flashinfer
|
||||
self.rope_matcher = MatcherRotaryEmbedding(
|
||||
is_neox=is_neox,
|
||||
head_size=self.head_dim,
|
||||
num_heads=self.num_heads,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
use_flashinfer=self.rope_flashinfer,
|
||||
)
|
||||
|
||||
def get_inputs(self):
|
||||
# Sample inputs to help pattern tracing
|
||||
T = 5
|
||||
qkv = empty_bf16(T, self.q_size + 2 * self.kv_size)
|
||||
positions = empty_i64(T)
|
||||
q_weight = empty_bf16(1, self.head_dim)
|
||||
k_weight = empty_bf16(1, self.head_dim)
|
||||
if self.rope_flashinfer:
|
||||
cos_sin_cache = empty_fp32(4096, self.head_dim)
|
||||
else:
|
||||
cos_sin_cache = empty_bf16(4096, self.head_dim)
|
||||
return [
|
||||
qkv,
|
||||
positions,
|
||||
q_weight,
|
||||
k_weight,
|
||||
cos_sin_cache,
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def wrap_trace_fn(trace_fn, *process_fx_fns: Callable[[fx.GraphModule], None]):
|
||||
def wrapped(*args, **kwargs):
|
||||
gm = trace_fn(*args, **kwargs)
|
||||
for process_fx in process_fx_fns:
|
||||
process_fx(gm)
|
||||
|
||||
return gm
|
||||
|
||||
return wrapped
|
||||
|
||||
@staticmethod
|
||||
def fx_view_to_reshape(gm: torch.fx.GraphModule):
|
||||
from torch._inductor.fx_passes.post_grad import view_to_reshape
|
||||
|
||||
view_to_reshape(gm)
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def pattern(
|
||||
qkv: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
q_weight: torch.Tensor,
|
||||
k_weight: torch.Tensor,
|
||||
cos_sin_cache: torch.Tensor,
|
||||
):
|
||||
# split qkv -> q,k,v
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
|
||||
# Q path: view -> RMS -> view back to q.shape
|
||||
q_by_head = q.view(
|
||||
*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim
|
||||
)
|
||||
q_normed_by_head = self.rmsnorm_matcher(q_by_head, q_weight)
|
||||
q_flat = q_normed_by_head.view(q.shape)
|
||||
|
||||
# K path: view -> RMS -> view back to k.shape
|
||||
k_by_head = k.view(
|
||||
*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim
|
||||
)
|
||||
k_normed_by_head = self.rmsnorm_matcher(k_by_head, k_weight)
|
||||
k_flat = k_normed_by_head.view(k.shape)
|
||||
|
||||
# RoPE: apply to flattened q/k
|
||||
q_rope, k_rope = self.rope_matcher(positions, q_flat, k_flat, cos_sin_cache)
|
||||
return q_rope, k_rope, v
|
||||
|
||||
def replacement(
|
||||
qkv: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
q_weight: torch.Tensor,
|
||||
k_weight: torch.Tensor,
|
||||
cos_sin_cache: torch.Tensor,
|
||||
):
|
||||
# Run fused qk_norm_rope op
|
||||
result = auto_functionalized(
|
||||
FUSED_QK_ROPE_OP,
|
||||
qkv=qkv,
|
||||
num_heads_q=self.num_heads,
|
||||
num_heads_k=self.num_kv_heads,
|
||||
num_heads_v=self.num_kv_heads,
|
||||
head_dim=self.head_dim,
|
||||
eps=self.eps,
|
||||
q_weight=q_weight,
|
||||
k_weight=k_weight,
|
||||
cos_sin_cache=cos_sin_cache,
|
||||
is_neox=self.is_neox,
|
||||
position_ids=positions.view(-1),
|
||||
)
|
||||
result_qkv = result[1]
|
||||
|
||||
# Split back to q,k,v and return
|
||||
return result_qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
|
||||
# NOTE: use fx_view_to_reshape to unify view/reshape to simplify
|
||||
# pattern and increase matching opportunities
|
||||
pm.register_replacement(
|
||||
pattern,
|
||||
replacement,
|
||||
self.get_inputs(),
|
||||
QkNormRopePattern.wrap_trace_fn(
|
||||
pm.fwd_only,
|
||||
QkNormRopePattern.fx_view_to_reshape,
|
||||
),
|
||||
pm_pass,
|
||||
)
|
||||
|
||||
|
||||
class QKNormRoPEFusionPass(VllmPatternMatcherPass):
|
||||
"""Fuse Q/K RMSNorm + RoPE into fused_qk_norm_rope when the custom op exists."""
|
||||
|
||||
@enable_fake_mode
|
||||
def __init__(self, config: VllmConfig):
|
||||
super().__init__(config)
|
||||
self.patterns: PatternMatcherPass = PatternMatcherPass(
|
||||
pass_name="qk_norm_rope_fusion_pass"
|
||||
)
|
||||
|
||||
dtype = config.model_config.dtype
|
||||
if dtype not in (torch.bfloat16, torch.float16):
|
||||
logger.warning_once(
|
||||
"QK Norm+RoPE fusion not enabled: unsupported dtype %s", dtype
|
||||
)
|
||||
return
|
||||
|
||||
# use one attn layer to get meta (such as head_dim) for QkNormRopePattern
|
||||
attn_layers: dict[str, Attention] = get_layers_from_vllm_config(
|
||||
config, Attention
|
||||
)
|
||||
if len(attn_layers) == 0:
|
||||
logger.warning_once(
|
||||
"QK Norm+RoPE fusion enabled, but no Attention layers were discovered."
|
||||
)
|
||||
return
|
||||
layer = next(iter(attn_layers.values()))
|
||||
|
||||
for epsilon in [1e-5, 1e-6]:
|
||||
for neox in [True, False]:
|
||||
if RotaryEmbedding.enabled():
|
||||
for rope_flashinfer in [False, True]:
|
||||
QkNormRopePattern(
|
||||
head_dim=layer.head_size,
|
||||
num_heads=layer.num_heads,
|
||||
num_kv_heads=layer.num_kv_heads,
|
||||
eps=epsilon,
|
||||
is_neox=neox,
|
||||
rope_flashinfer=rope_flashinfer,
|
||||
).register(self.patterns)
|
||||
else:
|
||||
QkNormRopePattern(
|
||||
head_dim=layer.head_size,
|
||||
num_heads=layer.num_heads,
|
||||
num_kv_heads=layer.num_kv_heads,
|
||||
eps=epsilon,
|
||||
is_neox=neox,
|
||||
).register(self.patterns)
|
||||
|
||||
self.dump_patterns(config, self.patterns)
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: fx.Graph) -> None:
|
||||
self.matched_count = self.patterns.apply(graph)
|
||||
logger.debug("Fused QK Norm+RoPE on %s sites", self.matched_count)
|
||||
|
||||
def uuid(self):
|
||||
return VllmInductorPass.hash_source(self, QkNormRopePattern)
|
||||
363
compilation/sequence_parallelism.py
Normal file
363
compilation/sequence_parallelism.py
Normal file
@@ -0,0 +1,363 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import functools
|
||||
|
||||
import torch
|
||||
import torch._inductor.pattern_matcher as pm
|
||||
import torch.fx as fx
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
|
||||
from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
kFp8StaticTensorSym,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .inductor_pass import enable_fake_mode
|
||||
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm
|
||||
from .noop_elimination import NoOpEliminationPass
|
||||
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def get_first_out_wrapper(fn):
|
||||
@functools.wraps(fn)
|
||||
def wrapper(*args):
|
||||
return fn(*args)[0]
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class _SequenceParallelPatternHelper:
|
||||
"""Helper for sequence parallelism patterns."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
epsilon: float,
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
):
|
||||
self.epsilon = epsilon
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
self.tp_group = get_tp_group()
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
def _all_reduce(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return tensor_model_parallel_all_reduce(x)
|
||||
|
||||
def _reduce_scatter(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return torch.ops.vllm.reduce_scatter.default(
|
||||
x, dim=0, world_size=self.tp_size, group_name=self.tp_group.unique_name
|
||||
)
|
||||
|
||||
def _all_gather(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return torch.ops.vllm.all_gather.default(
|
||||
x, dim=0, world_size=self.tp_size, group_name=self.tp_group.unique_name
|
||||
)
|
||||
|
||||
|
||||
class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
|
||||
def __init__(self, epsilon: float, dtype: torch.dtype, device: str):
|
||||
super().__init__(epsilon, dtype, device)
|
||||
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
|
||||
|
||||
def get_inputs(self):
|
||||
input = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype)
|
||||
arg3_1 = torch.empty([4], device=self.device, dtype=self.dtype)
|
||||
|
||||
return [input, arg3_1]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def pattern(
|
||||
input: torch.Tensor,
|
||||
arg3_1: torch.Tensor,
|
||||
):
|
||||
all_reduce = self._all_reduce(input)
|
||||
rmsnorm = self.rmsnorm_matcher(all_reduce, arg3_1)
|
||||
|
||||
return rmsnorm, all_reduce
|
||||
|
||||
def replacement(
|
||||
input: torch.Tensor,
|
||||
arg3_1: torch.Tensor,
|
||||
):
|
||||
reduce_scatter = self._reduce_scatter(input)
|
||||
|
||||
rmsnorm = self.rmsnorm_matcher(reduce_scatter, arg3_1)
|
||||
all_gather = self._all_gather(rmsnorm)
|
||||
return all_gather, reduce_scatter
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
||||
)
|
||||
|
||||
|
||||
class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
|
||||
def __init__(self, epsilon: float, dtype: torch.dtype, device: str):
|
||||
super().__init__(epsilon, dtype, device)
|
||||
self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
|
||||
|
||||
def get_inputs(self):
|
||||
mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
|
||||
residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
|
||||
return [
|
||||
residual,
|
||||
mm_1,
|
||||
rms_norm_weights,
|
||||
]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def pattern(
|
||||
residual: torch.Tensor,
|
||||
mm_1: torch.Tensor,
|
||||
rms_norm_weights: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
all_reduce = self._all_reduce(mm_1)
|
||||
rmsnorm = self.rmsnorm_matcher(all_reduce, rms_norm_weights, residual)
|
||||
return rmsnorm[0], rmsnorm[1]
|
||||
|
||||
def replacement(
|
||||
residual: torch.Tensor,
|
||||
mm_1: torch.Tensor,
|
||||
rms_norm_weights: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# pattern matcher replaces from top-to-bottom,
|
||||
# so residual is still the full size here.
|
||||
# once the seqpar pattern with the previous rmsnorm is replaced
|
||||
reduce_scatter = self._reduce_scatter(mm_1)
|
||||
residual = residual[0 : reduce_scatter.size(0), ...]
|
||||
rmsnorm = self.rmsnorm_matcher(reduce_scatter, rms_norm_weights, residual)
|
||||
all_gather = self._all_gather(rmsnorm[0])
|
||||
# shape of residual changes but that's fine,
|
||||
# next node is already slicing it, now becomes a noop
|
||||
return all_gather, rmsnorm[1]
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
||||
)
|
||||
pm.register_replacement(
|
||||
get_first_out_wrapper(pattern),
|
||||
get_first_out_wrapper(replacement),
|
||||
self.get_inputs(),
|
||||
pm.fwd_only,
|
||||
pm_pass,
|
||||
)
|
||||
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
|
||||
|
||||
class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
|
||||
def __init__(
|
||||
self,
|
||||
epsilon: float,
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
):
|
||||
super().__init__(epsilon, dtype, device)
|
||||
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
|
||||
self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
|
||||
|
||||
def get_inputs(self):
|
||||
input = torch.zeros([1, 8, 4], device=self.device, dtype=self.dtype)
|
||||
weight = torch.empty([4], device=self.device, dtype=self.dtype)
|
||||
scale = torch.tensor(1.0, device=self.device, dtype=torch.float32)
|
||||
return [input, weight, scale]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def pattern(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
all_reduce = self._all_reduce(input)
|
||||
rms = self.rmsnorm_matcher(all_reduce, weight)
|
||||
quant, _ = self.quant_matcher(rms, scale)
|
||||
return quant, all_reduce
|
||||
|
||||
def replacement(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
reduce_scatter = self._reduce_scatter(input)
|
||||
rms = self.rmsnorm_matcher(reduce_scatter, weight)
|
||||
quant, _ = self.quant_matcher(rms, scale)
|
||||
all_gather = self._all_gather(quant)
|
||||
|
||||
return all_gather, reduce_scatter
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
||||
)
|
||||
|
||||
|
||||
class MiddleAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
|
||||
def __init__(self, epsilon: float, dtype: torch.dtype, device: str):
|
||||
super().__init__(epsilon, dtype, device)
|
||||
self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
|
||||
self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
|
||||
|
||||
def get_inputs(self):
|
||||
mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
scale = torch.empty([1, 1], device=self.device, dtype=torch.float32)
|
||||
|
||||
return [residual, mm_1, rms_norm_weights, scale]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def pattern(
|
||||
residual: torch.Tensor,
|
||||
mm_1: torch.Tensor,
|
||||
rms_norm_weights: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
all_reduce = self._all_reduce(mm_1)
|
||||
rms, residual_out = self.rmsnorm_matcher(
|
||||
all_reduce, rms_norm_weights, residual
|
||||
)
|
||||
quant, _ = self.quant_matcher(rms, scale)
|
||||
return quant, residual_out
|
||||
|
||||
def replacement(
|
||||
residual: torch.Tensor,
|
||||
mm_1: torch.Tensor,
|
||||
rms_norm_weights: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# pattern matcher replaces from top-to-bottom,
|
||||
# so residual is still the full size here.
|
||||
# add a temporary slice which will become a noop
|
||||
# once the seqpar pattern with the previous rmsnorm is replaced
|
||||
reduce_scatter = self._reduce_scatter(mm_1)
|
||||
residual = residual[0 : reduce_scatter.size(0), ...]
|
||||
rms, residual_out = self.rmsnorm_matcher(
|
||||
reduce_scatter, rms_norm_weights, residual
|
||||
)
|
||||
quant, _ = self.quant_matcher(rms, scale)
|
||||
all_gather = self._all_gather(quant)
|
||||
# shape of residual changes but that's fine,
|
||||
# next node is already slicing it, now becomes a noop
|
||||
return all_gather, residual_out
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
||||
)
|
||||
|
||||
pm.register_replacement(
|
||||
get_first_out_wrapper(pattern),
|
||||
get_first_out_wrapper(replacement),
|
||||
self.get_inputs(),
|
||||
pm.fwd_only,
|
||||
pm_pass,
|
||||
)
|
||||
|
||||
|
||||
class SequenceParallelismPass(VllmPatternMatcherPass):
|
||||
"""
|
||||
This pass enables sequence parallelism for models.
|
||||
It identifies patterns where an AllReduce operation is followed by
|
||||
an RMSNorm (or RMSNorm and then Quantization) operation.
|
||||
These patterns are replaced with a ReduceScatter operation, followed by
|
||||
a local RMSNorm/Quantization, and then an AllGather operation.
|
||||
|
||||
The general transformation is:
|
||||
Input -> AllReduce -> RMSNorm -> Output
|
||||
becomes
|
||||
Input -> ReduceScatter -> RMSNorm -> AllGather -> Output
|
||||
|
||||
While this pass itself does not directly yield performance improvements,
|
||||
it lays the groundwork for subsequent fusion passes, such as
|
||||
GEMM + ReduceScatter and AllGather + GEMM fusions. These fusions can
|
||||
significantly reduce communication overhead and improve overall model
|
||||
performance.
|
||||
|
||||
|
||||
This pass splits up the residual tensor across TP ranks and hence divides its size.
|
||||
Because the pattern matcher starts at the end of the graph, the replacement
|
||||
contains a slice that temporarily conforms the input residual to the correct size.
|
||||
After all patterns have been matched, we use a NoOpEliminationPass to clean up
|
||||
what have now become no-op slices.
|
||||
|
||||
Note that an older version of the pass did not need this as it operated only on
|
||||
custom rms_norm and fused_rms_norm_add custom ops which did not complain about
|
||||
mismatched shapes during replacement. So this approach has the same assumption that
|
||||
correctness is only maintained if all rms_norm operations are split across ranks.
|
||||
|
||||
Correctness-wise, this is approach strictly better than before - before,
|
||||
the graph was incorrect semantically and shape-wise during the pass.
|
||||
With this approach there's only semantic incorrectness during the pass.
|
||||
Both approaches restore a correct graph once all patterns are matched.
|
||||
"""
|
||||
|
||||
@enable_fake_mode
|
||||
def __init__(self, config: VllmConfig):
|
||||
super().__init__(config)
|
||||
|
||||
# Used to cleanup redundant views created temporarily
|
||||
# to circumvent residual shape change issues
|
||||
self.noop_cleanup = NoOpEliminationPass(config)
|
||||
self.noop_cleanup.pass_name = f"{self.pass_name}.{self.noop_cleanup.pass_name}"
|
||||
|
||||
self.patterns: PatternMatcherPass = PatternMatcherPass(
|
||||
pass_name="sequence_parallelism_pass"
|
||||
)
|
||||
|
||||
for epsilon in [1e-5, 1e-6]:
|
||||
# RMSNorm + Static FP8 quantization patterns
|
||||
FirstAllReduceRMSNormStaticFP8Pattern(
|
||||
epsilon, self.model_dtype, self.device
|
||||
).register(self.patterns)
|
||||
MiddleAllReduceRMSNormStaticFP8Pattern(
|
||||
epsilon, self.model_dtype, self.device
|
||||
).register(self.patterns)
|
||||
|
||||
# Normal RMSNorm patterns
|
||||
FirstAllReduceRMSNormPattern(
|
||||
epsilon, self.model_dtype, self.device
|
||||
).register(self.patterns)
|
||||
|
||||
MiddleAllReduceRMSNormPattern(
|
||||
epsilon, self.model_dtype, self.device
|
||||
).register(self.patterns)
|
||||
|
||||
self.dump_patterns(config, self.patterns)
|
||||
|
||||
def is_applicable(self, shape: int | None) -> bool:
|
||||
# When sequence parallelism is enabled, the residual tensor from RMSNorm
|
||||
# needs to be split along the sequence dimension. However, this dimension
|
||||
# is symbolic during piecewise compilation, and splitting symbolic shapes
|
||||
# is not supported.
|
||||
#
|
||||
# This pass is therefore only applied when the sequence dimension is
|
||||
# concrete:
|
||||
# 1. In full-graph compilation mode (no Dynamo splitting ops are used).
|
||||
# For this case we always pad num_tokens to be a multiple of
|
||||
# tensor_parallel_size, so there's no need to check shape % tp_size == 0.
|
||||
# 2. For specific shape provided during compilation (e.g., from
|
||||
# `compile_sizes`), which must be divisible by the tensor-parallel
|
||||
# size.
|
||||
if (
|
||||
not self.compilation_config.splitting_ops
|
||||
or self.compilation_config.use_inductor_graph_partition
|
||||
):
|
||||
return True
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
return shape is not None and shape % tp_size == 0
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: fx.Graph):
|
||||
self.matched_count = self.patterns.apply(graph)
|
||||
logger.debug("Replaced %s patterns", self.matched_count)
|
||||
# Clean up reshape nodes
|
||||
self.noop_cleanup(graph)
|
||||
44
compilation/torch25_custom_graph_pass.py
Normal file
44
compilation/torch25_custom_graph_pass.py
Normal file
@@ -0,0 +1,44 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class Torch25CustomGraphPass(ABC): # noqa (redefinition)
|
||||
"""
|
||||
This class replaces CustomGraphPass from torch==2.6 when using torch<2.6.
|
||||
It conforms to the 2.6 interface but also supports pickling, as that's what
|
||||
the inductor code cache uses to determine the cache key before 2.6.
|
||||
(in 2.6 and above, uuid() is used.)
|
||||
|
||||
Subclasses can just "pretend" that uuid is used.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def __call__(self, graph: torch.fx.graph.Graph) -> None:
|
||||
"""
|
||||
Implementation of the custom pass.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def uuid(self) -> Any | None:
|
||||
"""
|
||||
Return an ID to uniquely identify your custom pass implementation.
|
||||
Return None to skip inductor code caching entirely.
|
||||
"""
|
||||
|
||||
def __getstate__(self):
|
||||
"""
|
||||
Pickling is used instead of uuid() in torch<2.6. Just return uuid()
|
||||
to enable subclasses to only have to implement uuid.
|
||||
"""
|
||||
return self.uuid()
|
||||
|
||||
def __setstate__(self, state):
|
||||
raise ValueError(
|
||||
"Cannot unpickle CustomGraphPass because pickling"
|
||||
" is used for cache key uuid. Use torch>=2.6 with"
|
||||
" native uuid support for custom passes."
|
||||
)
|
||||
173
compilation/vllm_inductor_pass.py
Normal file
173
compilation/vllm_inductor_pass.py
Normal file
@@ -0,0 +1,173 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import functools
|
||||
import operator
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar
|
||||
|
||||
import regex as re
|
||||
import torch
|
||||
from torch._dynamo.utils import lazy_format_graph_code
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass, PatternPrettyPrinter
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .inductor_pass import InductorPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class InductorCompilationConfig:
|
||||
splitting_ops: list[str] | None = None
|
||||
use_inductor_graph_partition: bool = False
|
||||
|
||||
|
||||
class VllmInductorPass(InductorPass):
|
||||
"""
|
||||
An inductor pass with access to vLLM PassConfig.
|
||||
It provides timing, logging, and dumping utilities.
|
||||
"""
|
||||
|
||||
dump_prefix: ClassVar[int | None] = None
|
||||
"""Keep track of pass index for debug dump ordering."""
|
||||
|
||||
def __init__(self, config: VllmConfig):
|
||||
# Get only the necessary CompilationConfig for the inductor pass, since
|
||||
# full `CompilationConfig` contains pointer to model which is unsafe.
|
||||
self.compilation_config = InductorCompilationConfig(
|
||||
splitting_ops=config.compilation_config.splitting_ops,
|
||||
use_inductor_graph_partition=config.compilation_config.use_inductor_graph_partition,
|
||||
)
|
||||
self.pass_config = config.compilation_config.pass_config
|
||||
self.model_dtype = config.model_config.dtype if config.model_config else None
|
||||
self.device = config.device_config.device if config.device_config else None
|
||||
self.pass_name = self.__class__.__name__
|
||||
|
||||
@staticmethod
|
||||
def time_and_log(call_fn):
|
||||
@functools.wraps(call_fn)
|
||||
def wrapped(self: VllmInductorPass, graph: torch.fx.Graph):
|
||||
self.begin()
|
||||
self.dump_graph(graph, "before")
|
||||
call_fn(self, graph)
|
||||
self.dump_graph(graph, "after")
|
||||
self.end_and_log()
|
||||
|
||||
return wrapped
|
||||
|
||||
def dump_graph(self, graph: torch.fx.Graph, stage: str):
|
||||
i = VllmInductorPass.dump_prefix
|
||||
i_str = "" if i is None else f".{i}"
|
||||
lazy_format_graph_code(
|
||||
f"post_grad{i_str}.{self.pass_name}.{stage}", graph.owning_module
|
||||
)
|
||||
|
||||
def begin(self):
|
||||
self._start_time = time.perf_counter_ns()
|
||||
|
||||
def end_and_log(self):
|
||||
self._end_time = time.perf_counter_ns()
|
||||
duration_ms = float(self._end_time - self._start_time) / 1.0e6
|
||||
logger.debug("%s completed in %.1f ms", self.pass_name, duration_ms)
|
||||
|
||||
|
||||
class VllmPatternMatcherPass(VllmInductorPass):
|
||||
"""
|
||||
A VllmInductorPass that uses the Inductor pattern matcher.
|
||||
Its main use is providing the dump_patterns utility that dumps the
|
||||
Inductor pattern matcher patterns into a file, which greatly aids debugging.
|
||||
|
||||
TODO(luka) move more utilities to this pass.
|
||||
"""
|
||||
|
||||
matched_count: int = 0
|
||||
"""The number of matched patterns in the pass."""
|
||||
|
||||
_OP_OVERLOAD_PATTERN: ClassVar[re.Pattern] = re.compile(
|
||||
r"<OpOverload\(op='([^']*)', overload='([^']*)'\)>"
|
||||
)
|
||||
|
||||
def _replace_op_overloads(self, string: str) -> str:
|
||||
"""Replace <OpOverload(..., ...)> with nicer formulations"""
|
||||
return self._OP_OVERLOAD_PATTERN.sub(
|
||||
lambda m: f"torch.ops.{m.group(1)}.{m.group(2)}",
|
||||
string,
|
||||
)
|
||||
|
||||
def dump_patterns(self, config: VllmConfig, pm_pass: PatternMatcherPass):
|
||||
"""
|
||||
If debug dumping is enabled, dump the Inductor pattern-matcher patterns
|
||||
into the debug_dump_path folder next to the dumped fx graphs.
|
||||
|
||||
This method does its best to print something that looks like Python code
|
||||
for easier debugging and potentially navigation. If any errors appear in
|
||||
the output, please add to this method.
|
||||
|
||||
TODO(luka): use pattern object to manually produce pattern graph
|
||||
"""
|
||||
debug_dump_path = config.compile_debug_dump_path()
|
||||
if not debug_dump_path:
|
||||
return
|
||||
|
||||
debug_dump_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
from vllm.utils.system_utils import unique_filepath
|
||||
|
||||
file_path = unique_filepath(
|
||||
lambda i: debug_dump_path / f"patterns.{self.pass_name}.{i}.py"
|
||||
)
|
||||
|
||||
with file_path.open("w") as f:
|
||||
print(
|
||||
f"# This file was produced by VllmPatternMatcherPass."
|
||||
f"dump_patterns for {self.pass_name}.\n"
|
||||
f"# It does its best to produce valid-Python-looking code but"
|
||||
f" please add to dump_patterns if there are any errors.\n\n"
|
||||
f"from torch._higher_order_ops.auto_functionalize import "
|
||||
f"auto_functionalized as auto_functionalized\n"
|
||||
f"from torch._inductor.pattern_matcher import *\n"
|
||||
f"vllm = torch.ops.vllm",
|
||||
file=f,
|
||||
)
|
||||
|
||||
for node, patterns in pm_pass.patterns.items():
|
||||
# fix the operator.getitem repr
|
||||
if node[1] == operator.getitem:
|
||||
node_repr = f"({repr(node[0])}, operator.getitem)"
|
||||
else:
|
||||
node_repr = repr(node)
|
||||
|
||||
node_repr = self._replace_op_overloads(node_repr)
|
||||
|
||||
print(f"\n\n# Patterns for op: {node_repr}", file=f)
|
||||
for i, pattern in enumerate(patterns):
|
||||
# reserve auto_functionalized ahead of time
|
||||
pp = PatternPrettyPrinter()
|
||||
pp.namespace.create_name("auto_functionalized", None)
|
||||
|
||||
# Assemble pattern
|
||||
out_node = pp.pretty_print(pattern.pattern)
|
||||
pattern_repr = "\n".join(
|
||||
[f"def pattern_{i}():"]
|
||||
+ [
|
||||
f"{pp.memoized_objs_names[key]} = "
|
||||
f"{pp.memoized_objs_pp[key]}"
|
||||
for key in pp.memoized_objs_names
|
||||
]
|
||||
+ [f"return {out_node}"]
|
||||
).replace("\n", "\n ")
|
||||
|
||||
pattern_repr = self._replace_op_overloads(pattern_repr)
|
||||
print(f"{pattern_repr}\n", file=f)
|
||||
|
||||
|
||||
class PrinterInductorPass(VllmInductorPass):
|
||||
def __init__(self, name: str, config: VllmConfig):
|
||||
super().__init__(config)
|
||||
self.name = name
|
||||
|
||||
def __call__(self, graph: torch.fx.Graph):
|
||||
self.dump_graph(graph, self.name)
|
||||
238
compilation/wrapper.py
Normal file
238
compilation/wrapper.py
Normal file
@@ -0,0 +1,238 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import os
|
||||
import sys
|
||||
from abc import abstractmethod
|
||||
from contextlib import contextmanager
|
||||
from types import CodeType
|
||||
|
||||
import torch
|
||||
import torch._C._dynamo.guards
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import CompilationMode, CUDAGraphMode, get_current_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def _noop_add_global_state_guard(self, *args, **kwargs):
|
||||
"""No-op to skip the GLOBAL_STATE guard entirely"""
|
||||
pass
|
||||
|
||||
|
||||
def _noop_add_torch_function_mode_stack_guard(self, *args, **kwargs):
|
||||
"""No-op to skip the TORCH_FUNCTION_MODE_STACK guard entirely"""
|
||||
pass
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _compilation_context():
|
||||
"""Context manager for compilation settings and patches.
|
||||
|
||||
This manager:
|
||||
1. Sets higher dynamo cache limits for compilation. (Needed for
|
||||
qwen2_5_vl see test_qwen2_5_vl_evs_functionality).
|
||||
Generally a recompilation can happen whenever we use a new
|
||||
backend instance in torch.compile.
|
||||
2. Patches out add_global_state_guard to skip GLOBAL_STATE guards
|
||||
3. Patches out add_torch_function_mode_stack_guard to skip
|
||||
TORCH_FUNCTION_MODE_STACK guards.
|
||||
4. Restores everything when compilation completes
|
||||
"""
|
||||
# Save original values
|
||||
original_global_state_guard = (
|
||||
torch._C._dynamo.guards.GuardManager.add_global_state_guard
|
||||
)
|
||||
original_torch_function_mode_stack_guard = (
|
||||
torch._C._dynamo.guards.GuardManager.add_torch_function_mode_stack_guard
|
||||
)
|
||||
original_cache_size = torch._dynamo.config.cache_size_limit
|
||||
original_accumulated_cache = torch._dynamo.config.accumulated_cache_size_limit
|
||||
|
||||
try:
|
||||
# Set higher cache limits for compilation
|
||||
torch._dynamo.config.cache_size_limit = 2048
|
||||
torch._dynamo.config.accumulated_cache_size_limit = 8192
|
||||
|
||||
# Patch guard manager
|
||||
torch._C._dynamo.guards.GuardManager.add_global_state_guard = (
|
||||
_noop_add_global_state_guard
|
||||
)
|
||||
torch._C._dynamo.guards.GuardManager.add_torch_function_mode_stack_guard = (
|
||||
_noop_add_torch_function_mode_stack_guard
|
||||
)
|
||||
yield
|
||||
finally:
|
||||
# Restore original values
|
||||
torch._C._dynamo.guards.GuardManager.add_global_state_guard = (
|
||||
original_global_state_guard
|
||||
)
|
||||
torch._C._dynamo.guards.GuardManager.add_torch_function_mode_stack_guard = (
|
||||
original_torch_function_mode_stack_guard
|
||||
)
|
||||
torch._dynamo.config.cache_size_limit = original_cache_size
|
||||
torch._dynamo.config.accumulated_cache_size_limit = original_accumulated_cache
|
||||
|
||||
|
||||
class TorchCompileWithNoGuardsWrapper:
|
||||
"""
|
||||
A wrapper class for torch.compile, it ensures that all guards are dropped
|
||||
when CompilationMode is not CompilationMode.STOCK_TORCH_COMPILE.
|
||||
When guards are dropped, the first time __call__ is invoked, a single
|
||||
compilation is triggered. Dynamo should never be traced again after that
|
||||
since we drop all guards.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.compiled = False
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
self.vllm_config = vllm_config
|
||||
mode = vllm_config.compilation_config.mode
|
||||
if mode is None:
|
||||
raise RuntimeError("Compilation mode cannot be NO_COMPILATION")
|
||||
|
||||
backend = vllm_config.compilation_config.init_backend(vllm_config)
|
||||
options = {}
|
||||
|
||||
if isinstance(backend, str) and backend == "inductor":
|
||||
options = vllm_config.compilation_config.inductor_compile_config
|
||||
|
||||
if mode != CompilationMode.STOCK_TORCH_COMPILE:
|
||||
# Drop all the guards.
|
||||
options["guard_filter_fn"] = lambda x: [False for _ in x]
|
||||
|
||||
if envs.VLLM_USE_AOT_COMPILE:
|
||||
if hasattr(torch._dynamo.config, "enable_aot_compile"):
|
||||
torch._dynamo.config.enable_aot_compile = True
|
||||
else:
|
||||
msg = "torch._dynamo.config.enable_aot_compile is not "
|
||||
msg += "available. AOT compile is disabled and please "
|
||||
msg += "upgrade PyTorch version to use AOT compile."
|
||||
logger.warning(msg)
|
||||
|
||||
self._compiled_callable = torch.compile(
|
||||
self.forward,
|
||||
fullgraph=True,
|
||||
dynamic=False,
|
||||
backend=backend,
|
||||
options=options,
|
||||
)
|
||||
|
||||
if envs.VLLM_USE_BYTECODE_HOOK and mode != CompilationMode.STOCK_TORCH_COMPILE:
|
||||
torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook)
|
||||
self._compiled_bytecode = None
|
||||
|
||||
def aot_compile(self, *args, **kwargs):
|
||||
if not hasattr(self._compiled_callable, "aot_compile"):
|
||||
raise RuntimeError(
|
||||
"aot_compile is not supported by the current configuration. "
|
||||
+ "Please make sure torch.compile is enabled with the latest "
|
||||
+ f"version of PyTorch (current using torch: {torch.__version__})"
|
||||
)
|
||||
return self._compiled_callable.aot_compile((args, kwargs))
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
if envs.VLLM_USE_BYTECODE_HOOK:
|
||||
if (
|
||||
self.vllm_config.compilation_config.mode
|
||||
== CompilationMode.STOCK_TORCH_COMPILE
|
||||
):
|
||||
return self._compiled_callable(*args, **kwargs)
|
||||
|
||||
if not self._compiled_bytecode:
|
||||
# Make sure a compilation is triggered by clearing dynamo
|
||||
# cache.
|
||||
torch._dynamo.eval_frame.remove_from_cache(self.original_code_object())
|
||||
return self._compiled_callable(*args, **kwargs)
|
||||
else:
|
||||
with self._dispatch_to_compiled_code():
|
||||
return self.forward(*args, **kwargs)
|
||||
else:
|
||||
with _compilation_context():
|
||||
return self._compiled_callable(*args, **kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def forward(self, *args, **kwargs): ...
|
||||
|
||||
def original_code_object(self) -> CodeType:
|
||||
"""Return the original code object of the forward method."""
|
||||
return self.__class__.forward.__code__
|
||||
|
||||
def bytecode_hook(self, old_code: CodeType, new_code: CodeType):
|
||||
"""Hook to save the compiled bytecode for direct execution."""
|
||||
if old_code is not self.original_code_object():
|
||||
return
|
||||
# code borrowed from https://github.com/thuml/depyf/blob/f4ad79fadee27ea113b4c75202db1eb1a11c0dbc/depyf/explain/enable_debugging.py#L25
|
||||
frame = sys._getframe()
|
||||
while frame and frame.f_back:
|
||||
frame = frame.f_back
|
||||
code_name = frame.f_code.co_name
|
||||
file_name = frame.f_code.co_filename.split(os.path.sep)[-1]
|
||||
if code_name == "_compile" and file_name == "convert_frame.py":
|
||||
break
|
||||
frame = frame.f_locals["frame"]
|
||||
assert frame.f_code == old_code
|
||||
|
||||
if frame.f_locals["self"] is not self:
|
||||
return
|
||||
|
||||
self._compiled_bytecode = new_code
|
||||
|
||||
path = self.vllm_config.compile_debug_dump_path()
|
||||
if path:
|
||||
decompiled_file = path / "transformed_code.py"
|
||||
if not decompiled_file.exists():
|
||||
try:
|
||||
# usually the decompilation will succeed for most models,
|
||||
# as we guarantee a full-graph compilation in Dynamo.
|
||||
# but there's no 100% guarantee, since decompliation is
|
||||
# not a reversible process.
|
||||
import depyf
|
||||
|
||||
src = depyf.decompile(new_code)
|
||||
|
||||
with open(decompiled_file, "w") as f:
|
||||
f.write(src)
|
||||
|
||||
logger.debug("Dynamo transformed code saved to %s", decompiled_file)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if (
|
||||
self.vllm_config.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
|
||||
and "update" in new_code.co_names
|
||||
):
|
||||
import depyf
|
||||
|
||||
src = depyf.decompile(new_code)
|
||||
msg = (
|
||||
"Assigning / modifying buffers of nn.Module during forward pass is not "
|
||||
"allowed when using cudagraph inside the compiler because it will "
|
||||
"cause silent errors. Please use eager mode or fix the code. The "
|
||||
"following code contains clues about which buffer is being modified "
|
||||
f"(please search for the usage of the function `update`):\n{src}"
|
||||
)
|
||||
raise RuntimeError(msg)
|
||||
|
||||
@contextmanager
|
||||
def _dispatch_to_compiled_code(self):
|
||||
# noqa: E501
|
||||
"""
|
||||
Context manager to dispatch to internally compiled code for torch<2.8.
|
||||
Why does this work? Because Dynamo guarantees that the compiled
|
||||
bytecode has exactly the same arguments, cell variables, and free
|
||||
variables as the original code. Therefore we can directly switch
|
||||
the code object in the function and call it.
|
||||
|
||||
See https://dev-discuss.pytorch.org/t/what-is-the-relationship-requirement-among-original-bytecode-transformed-bytecode-and-bytecode-returned-by-hooks-in-dynamo/1693/7 for more details.
|
||||
""" # noqa: E501 line too long
|
||||
original = self.original_code_object()
|
||||
assert self._compiled_bytecode is not None
|
||||
self.__class__.forward.__code__ = self._compiled_bytecode
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self.__class__.forward.__code__ = original
|
||||
Reference in New Issue
Block a user