Sync from v0.13
This commit is contained in:
359
vllm/compilation/fusion_attn.py
Normal file
359
vllm/compilation/fusion_attn.py
Normal file
@@ -0,0 +1,359 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
import torch._inductor.pattern_matcher as pm
|
||||
from torch import fx
|
||||
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
kNvfp4Quant,
|
||||
kStaticTensorScale,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.math_utils import round_up
|
||||
|
||||
from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
|
||||
from .fx_utils import is_func
|
||||
from .inductor_pass import enable_fake_mode
|
||||
from .matcher_utils import MatcherQuantFP8
|
||||
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
FP4_DTYPE = torch.uint8
|
||||
|
||||
ATTN_OP = torch.ops.vllm.unified_attention_with_output.default
|
||||
RESHAPE_OP = torch.ops.aten.reshape.default
|
||||
|
||||
|
||||
class AttentionQuantPattern(ABC):
|
||||
"""
|
||||
The base class for Attn+Quant fusions.
|
||||
Should not be used directly.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer: Attention,
|
||||
quant_key: QuantKey,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
self.layer = layer
|
||||
self.layer_name = layer.layer_name
|
||||
self.num_heads = layer.num_heads
|
||||
self.head_size = layer.head_size
|
||||
self.quant_key = quant_key
|
||||
self.quant_dtype = quant_key.dtype
|
||||
self.dtype = dtype
|
||||
|
||||
assert self.quant_key in QUANT_OPS, (
|
||||
f"unsupported quantization scheme {self.quant_key}"
|
||||
)
|
||||
self.QUANT_OP = QUANT_OPS[self.quant_key]
|
||||
|
||||
def empty(self, *args, **kwargs):
|
||||
kwargs = {"dtype": self.dtype, "device": "cuda", **kwargs}
|
||||
return torch.empty(*args, **kwargs)
|
||||
|
||||
def empty_quant(self, *args, **kwargs):
|
||||
kwargs = {"dtype": self.quant_dtype, "device": "cuda", **kwargs}
|
||||
return torch.empty(*args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def wrap_trace_fn(trace_fn, *process_fx_fns: Callable[[fx.GraphModule], None]):
|
||||
def wrapped(*args, **kwargs):
|
||||
gm = trace_fn(*args, **kwargs)
|
||||
for process_fx in process_fx_fns:
|
||||
process_fx(gm)
|
||||
|
||||
return gm
|
||||
|
||||
return wrapped
|
||||
|
||||
@staticmethod
|
||||
def fx_view_to_reshape(gm: torch.fx.GraphModule):
|
||||
from torch._inductor.fx_passes.post_grad import view_to_reshape
|
||||
|
||||
view_to_reshape(gm)
|
||||
|
||||
@staticmethod
|
||||
def remove_noop_permutes(gm: torch.fx.GraphModule):
|
||||
for node in gm.graph.nodes:
|
||||
if not is_func(node, torch.ops.aten.permute.default):
|
||||
continue
|
||||
|
||||
dims = node.args[1]
|
||||
if any(dim != i for i, dim in enumerate(dims)):
|
||||
continue
|
||||
|
||||
# this is now an identity op, remove
|
||||
node.replace_all_uses_with(node.args[0])
|
||||
gm.graph.erase_node(node)
|
||||
|
||||
def register_if_supported(self, pm_pass: PatternMatcherPass):
|
||||
if self.layer.impl.fused_output_quant_supported(self.quant_key):
|
||||
self._register(pm_pass)
|
||||
|
||||
@abstractmethod
|
||||
def _register(self, pm_pass: PatternMatcherPass):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
|
||||
"""
|
||||
Fusion for Attention+Fp8StaticQuant.
|
||||
|
||||
Only triggers when the attention implementation returns True in
|
||||
`fused_output_quant_supported()`. If the pattern is found, the
|
||||
Fp8StaticQuant op will be removed from the graph, and its scale
|
||||
will be passed into Attention op as the `output_scale` argument.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer: Attention,
|
||||
dtype: torch.dtype,
|
||||
symmetric: bool = True,
|
||||
):
|
||||
quant_key = QuantKey(
|
||||
dtype=FP8_DTYPE, scale=kStaticTensorScale, symmetric=symmetric
|
||||
)
|
||||
super().__init__(layer, quant_key, dtype)
|
||||
self.quant_matcher = MatcherQuantFP8(quant_key)
|
||||
|
||||
def _register(self, pm_pass: PatternMatcherPass):
|
||||
def pattern(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
output_attn: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
at1 = auto_functionalized(
|
||||
ATTN_OP,
|
||||
query=q,
|
||||
key=k,
|
||||
value=v,
|
||||
output=output_attn,
|
||||
layer_name=self.layer_name,
|
||||
output_scale=None,
|
||||
output_block_scale=None,
|
||||
)
|
||||
attn_out_view = RESHAPE_OP(
|
||||
at1[1], [q.shape[0], self.num_heads * self.head_size]
|
||||
)
|
||||
|
||||
return self.quant_matcher(attn_out_view, scale)[0]
|
||||
|
||||
def replacement(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
output_attn: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
# attn output in quant_dtype
|
||||
output_attn = torch.ops.aten.full.default(
|
||||
[q.shape[0], self.num_heads, self.head_size],
|
||||
0.0,
|
||||
dtype=self.quant_dtype,
|
||||
device=q.device,
|
||||
)
|
||||
at1 = auto_functionalized(
|
||||
ATTN_OP,
|
||||
query=q,
|
||||
key=k,
|
||||
value=v,
|
||||
output=output_attn,
|
||||
layer_name=self.layer_name,
|
||||
output_scale=scale,
|
||||
output_block_scale=None,
|
||||
)
|
||||
return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size])
|
||||
|
||||
inputs = [
|
||||
self.empty(5, self.num_heads, self.head_size), # q
|
||||
self.empty(5, self.num_heads, self.head_size), # k
|
||||
self.empty(5, self.num_heads, self.head_size), # v
|
||||
self.empty(5, self.num_heads, self.head_size), # attn_output
|
||||
empty_fp32(1, 1), # scale
|
||||
]
|
||||
|
||||
pm.register_replacement(
|
||||
pattern,
|
||||
replacement,
|
||||
inputs,
|
||||
AttentionQuantPattern.wrap_trace_fn(
|
||||
pm.fwd_only,
|
||||
AttentionQuantPattern.fx_view_to_reshape,
|
||||
AttentionQuantPattern.remove_noop_permutes,
|
||||
),
|
||||
pm_pass,
|
||||
)
|
||||
|
||||
|
||||
class AttentionNvfp4QuantPattern(AttentionQuantPattern):
|
||||
"""
|
||||
Fusion for Attention+Nvfp4Quant.
|
||||
|
||||
Only triggers when the attention implementation returns True in
|
||||
`fused_output_quant_supported()`. If the pattern is found, the
|
||||
Nvfp4Quant op will be removed from the graph, and its scale
|
||||
will be passed into Attention op as the `output_scale` argument.
|
||||
"""
|
||||
|
||||
def __init__(self, layer: Attention, dtype: torch.dtype):
|
||||
super().__init__(layer, kNvfp4Quant, dtype)
|
||||
|
||||
def _register(self, pm_pass: PatternMatcherPass):
|
||||
def pattern(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
output_attn: torch.Tensor,
|
||||
output_quant: torch.Tensor,
|
||||
output_scale: torch.Tensor,
|
||||
input_scale: torch.Tensor,
|
||||
):
|
||||
at1 = auto_functionalized(
|
||||
ATTN_OP,
|
||||
query=q,
|
||||
key=k,
|
||||
value=v,
|
||||
output=output_attn,
|
||||
layer_name=self.layer_name,
|
||||
output_scale=None,
|
||||
output_block_scale=None,
|
||||
)
|
||||
attn_out_view = RESHAPE_OP(
|
||||
at1[1], [q.shape[0], self.num_heads * self.head_size]
|
||||
)
|
||||
at2 = auto_functionalized(
|
||||
self.QUANT_OP,
|
||||
output=output_quant,
|
||||
input=attn_out_view,
|
||||
output_scale=output_scale,
|
||||
input_scale=input_scale,
|
||||
)
|
||||
output_scale_view = torch.ops.aten.view.dtype(at2[2], FP8_DTYPE)
|
||||
return at2[1], output_scale_view
|
||||
|
||||
def replacement(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
output_attn: torch.Tensor,
|
||||
output_quant: torch.Tensor,
|
||||
output_scale: torch.Tensor,
|
||||
input_scale: torch.Tensor,
|
||||
):
|
||||
# attention output in quant_dtype
|
||||
output_attn = torch.ops.aten.full.default(
|
||||
[q.shape[0], self.num_heads, self.head_size // 2],
|
||||
0.0,
|
||||
dtype=self.quant_dtype,
|
||||
device=q.device,
|
||||
)
|
||||
# attention output block scale
|
||||
output_scale_view = torch.ops.aten.view.dtype(output_scale, FP8_DTYPE)
|
||||
at2 = auto_functionalized(
|
||||
ATTN_OP,
|
||||
query=q,
|
||||
key=k,
|
||||
value=v,
|
||||
output=output_attn,
|
||||
layer_name=self.layer_name,
|
||||
output_scale=input_scale,
|
||||
output_block_scale=output_scale_view,
|
||||
)
|
||||
output = RESHAPE_OP(at2[1], [-1, self.num_heads * self.head_size // 2])
|
||||
return output, at2[2]
|
||||
|
||||
inputs = [
|
||||
empty_bf16(5, self.num_heads, self.head_size), # q
|
||||
empty_bf16(5, self.num_heads, self.head_size), # k
|
||||
empty_bf16(5, self.num_heads, self.head_size), # v
|
||||
empty_bf16(5, self.num_heads, self.head_size), # output_attn
|
||||
self.empty_quant(5, self.num_heads * self.head_size // 2), # output_quant
|
||||
empty_i32(
|
||||
128, round_up(self.num_heads * self.head_size // 16, 4)
|
||||
), # output_scale
|
||||
empty_fp32(1, 1), # input_scale
|
||||
]
|
||||
|
||||
pm.register_replacement(
|
||||
pattern,
|
||||
replacement,
|
||||
inputs,
|
||||
AttentionQuantPattern.wrap_trace_fn(
|
||||
pm.fwd_only,
|
||||
AttentionQuantPattern.fx_view_to_reshape,
|
||||
AttentionQuantPattern.remove_noop_permutes,
|
||||
),
|
||||
pm_pass,
|
||||
)
|
||||
|
||||
|
||||
class AttnFusionPass(VllmPatternMatcherPass):
|
||||
"""
|
||||
This pass fuses post-attention quantization onto attention if supported.
|
||||
|
||||
It uses the pattern matcher and matches each layer manually, as strings
|
||||
cannot be wildcarded. This also lets us check support on attention layers
|
||||
upon registration instead of during pattern matching.
|
||||
|
||||
Currently, only static fp8 quant is supported, but patterns could easily be
|
||||
added for other quant schemes and dtypes. The bigger hurdle for wider
|
||||
support are attention kernels, which need to support fusing output quant.
|
||||
"""
|
||||
|
||||
@enable_fake_mode
|
||||
def __init__(self, config: VllmConfig):
|
||||
super().__init__(config)
|
||||
|
||||
self.patterns = PatternMatcherPass(pass_name="attn_fusion_pass")
|
||||
|
||||
attn_layers = get_layers_from_vllm_config(config, Attention)
|
||||
for layer_name, layer in attn_layers.items():
|
||||
pattern_fp8 = AttentionFp8StaticQuantPattern(
|
||||
layer, config.model_config.dtype
|
||||
)
|
||||
pattern_fp8.register_if_supported(self.patterns)
|
||||
|
||||
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
|
||||
pattern_nvfp4 = AttentionNvfp4QuantPattern(
|
||||
layer, config.model_config.dtype
|
||||
)
|
||||
pattern_nvfp4.register_if_supported(self.patterns)
|
||||
|
||||
if len(attn_layers) == 0:
|
||||
logger.warning(
|
||||
"Attention + quant fusion is enabled, but no attention layers "
|
||||
"were found in CompilationConfig.static_forward_context "
|
||||
"so no fusion patterns were registered."
|
||||
)
|
||||
|
||||
self.dump_patterns(config, self.patterns)
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: torch.fx.graph.Graph) -> None:
|
||||
self.matched_count = self.patterns.apply(graph)
|
||||
logger.debug("Fused quant onto %s attention nodes", self.matched_count)
|
||||
|
||||
def uuid(self):
|
||||
return VllmInductorPass.hash_source(
|
||||
self,
|
||||
AttentionQuantPattern,
|
||||
AttentionFp8StaticQuantPattern,
|
||||
AttentionNvfp4QuantPattern,
|
||||
)
|
||||
Reference in New Issue
Block a user