forked from EngineX-Hygon/enginex-hygon-vllm
init src 0.9.2
This commit is contained in:
166
vllm/compilation/fusion_attn.py
Normal file
166
vllm/compilation/fusion_attn.py
Normal file
@@ -0,0 +1,166 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
import torch._inductor.pattern_matcher as pm
|
||||
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||
from torch._subclasses.fake_tensor import (FakeTensorMode,
|
||||
unset_fake_temporarily)
|
||||
|
||||
from vllm.attention import Attention
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .fusion import QUANT_OPS, GroupShape, QuantKey, empty_bf16, empty_fp32
|
||||
from .vllm_inductor_pass import VllmInductorPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
ATTN_OP = torch.ops.vllm.unified_attention_with_output.default
|
||||
RESHAPE_OP = torch.ops.aten.reshape.default
|
||||
|
||||
|
||||
class AttentionStaticQuantPattern:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_name: str,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
quant_dtype: torch.dtype,
|
||||
symmetric=True,
|
||||
):
|
||||
self.layer_name = layer_name
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.quant_dtype = quant_dtype
|
||||
self.quant_key = QuantKey(dtype=quant_dtype,
|
||||
static=True,
|
||||
group_shape=GroupShape.PER_TENSOR,
|
||||
symmetric=symmetric)
|
||||
assert self.quant_key in QUANT_OPS, \
|
||||
f"unsupported quantization scheme {self.quant_key}"
|
||||
self.QUANT_OP = QUANT_OPS[self.quant_key]
|
||||
|
||||
def empty_quant(self, *args, **kwargs):
|
||||
kwargs = {'dtype': self.quant_dtype, 'device': "cuda", **kwargs}
|
||||
return torch.empty(*args, **kwargs)
|
||||
|
||||
def register_if_supported(self, pm_pass: PatternMatcherPass,
|
||||
layer: Attention):
|
||||
if layer.impl.fused_output_quant_supported(self.quant_dtype,
|
||||
self.quant_key.static,
|
||||
self.quant_key.group_shape):
|
||||
self._register(pm_pass)
|
||||
|
||||
def _register(self, pm_pass: PatternMatcherPass):
|
||||
|
||||
def pattern(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||
output_attn: torch.Tensor, output_quant: torch.Tensor,
|
||||
scale: torch.Tensor):
|
||||
view_7 = RESHAPE_OP(output_attn,
|
||||
[-1, self.num_heads, self.head_size])
|
||||
|
||||
at1 = auto_functionalized(ATTN_OP,
|
||||
query=q,
|
||||
key=k,
|
||||
value=v,
|
||||
output=view_7,
|
||||
layer_name=self.layer_name,
|
||||
output_scale=None)
|
||||
attn_out_view = RESHAPE_OP(at1[1],
|
||||
[-1, self.num_heads * self.head_size])
|
||||
|
||||
at2 = auto_functionalized(self.QUANT_OP,
|
||||
result=output_quant,
|
||||
input=attn_out_view,
|
||||
scale=scale)
|
||||
return at2[1]
|
||||
|
||||
def replacement(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||
output_attn: torch.Tensor, output_quant: torch.Tensor,
|
||||
scale: torch.Tensor):
|
||||
view_7 = RESHAPE_OP(output_quant,
|
||||
[-1, self.num_heads, self.head_size])
|
||||
|
||||
at1 = auto_functionalized(ATTN_OP,
|
||||
query=q,
|
||||
key=k,
|
||||
value=v,
|
||||
output=view_7,
|
||||
layer_name=self.layer_name,
|
||||
output_scale=scale)
|
||||
|
||||
return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size])
|
||||
|
||||
# Need custom fake mode, otherwise tracing happens with real tensors.
|
||||
# That would not work for the unified_attention custom op.
|
||||
with unset_fake_temporarily(), FakeTensorMode():
|
||||
inputs = [
|
||||
empty_bf16(5, self.num_heads, self.head_size), # q
|
||||
empty_bf16(5, self.num_heads, self.head_size), # k
|
||||
empty_bf16(5, self.num_heads, self.head_size), # v
|
||||
empty_bf16(5, self.num_heads * self.head_size), # attn_output
|
||||
self.empty_quant(5, self.num_heads *
|
||||
self.head_size), # quant_output
|
||||
empty_fp32(1, 1) # scale
|
||||
]
|
||||
|
||||
def wrap_trace_fn(process_fx, trace_fn):
|
||||
|
||||
def wrapped(*args, **kwargs):
|
||||
return process_fx(trace_fn(*args, **kwargs))
|
||||
|
||||
return wrapped
|
||||
|
||||
def fx_view_to_reshape(gm: torch.fx.GraphModule):
|
||||
from torch._inductor.fx_passes.post_grad import view_to_reshape
|
||||
view_to_reshape(gm)
|
||||
return gm
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, inputs,
|
||||
wrap_trace_fn(fx_view_to_reshape, pm.fwd_only), pm_pass)
|
||||
|
||||
|
||||
class AttnFusionPass(VllmInductorPass):
|
||||
"""
|
||||
This pass fuses post-attention quantization onto attention if supported.
|
||||
|
||||
It uses the pattern matcher and matches each layer manually, as strings
|
||||
cannot be wildcarded. This also lets us check support on attention layers
|
||||
upon registration instead of during pattern matching.
|
||||
|
||||
Currently, only static fp8 quant is supported, but patterns could easily be
|
||||
added for other quant schemes and dtypes. The bigger hurdle for wider
|
||||
support are attention kernels, which need to support fusing output quant.
|
||||
"""
|
||||
|
||||
def __init__(self, config: VllmConfig):
|
||||
super().__init__(config)
|
||||
self.static_fwd_ctx = config.compilation_config.static_forward_context
|
||||
|
||||
self.patterns = PatternMatcherPass(pass_name="attn_fusion_pass")
|
||||
|
||||
for key, layer in self.static_fwd_ctx.items():
|
||||
pattern = AttentionStaticQuantPattern(key, layer.num_heads,
|
||||
layer.head_size,
|
||||
current_platform.fp8_dtype())
|
||||
pattern.register_if_supported(self.patterns, layer)
|
||||
if len(self.static_fwd_ctx) == 0:
|
||||
logger.warning(
|
||||
"Attention + quant fusion is enabled, but "
|
||||
"CompilationConfig.static_forward_context is empty. "
|
||||
"Cannot access attention layers so no fusion "
|
||||
"patterns were registered.")
|
||||
|
||||
def __call__(self, graph: torch.fx.graph.Graph) -> None:
|
||||
self.begin()
|
||||
self.dump_graph(graph, "before_attn_fusion")
|
||||
|
||||
count = self.patterns.apply(graph)
|
||||
logger.debug("Fused quantization onto %s attention nodes", count)
|
||||
self.dump_graph(graph, "after_attn_fusion")
|
||||
self.end_and_log()
|
||||
Reference in New Issue
Block a user