Add minimal vLLM 0.16.1 build repo for BI-V150

This commit is contained in:
2026-04-18 10:56:22 +08:00
commit d69657327e
1895 changed files with 615301 additions and 0 deletions

View File

View File

@@ -0,0 +1,215 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from typing import Any
import torch
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import (
PatternMatcherPass,
fwd_only,
register_replacement,
)
from torch._ops import OpOverload
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kFp8StaticTensorSym,
kNvfp4Dynamic,
)
from vllm.platforms import current_platform
from ..inductor_pass import enable_fake_mode
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
from .matcher_utils import MatcherQuantFP8, MatcherSiluAndMul
from .rms_quant_fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
logger = init_logger(__name__)
FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8
SILU_MUL_OP = torch.ops._C.silu_and_mul.default
FUSED_OPS: dict[QuantKey, OpOverload] = {
kFp8StaticTensorSym: torch.ops._C.silu_and_mul_quant.default, # noqa: E501
}
silu_and_mul_nvfp4_quant_supported = current_platform.is_cuda() and hasattr(
torch.ops._C, "silu_and_mul_nvfp4_quant"
)
if silu_and_mul_nvfp4_quant_supported:
FUSED_OPS[kNvfp4Dynamic] = torch.ops._C.silu_and_mul_nvfp4_quant.default # noqa: E501
class ActivationQuantPattern(ABC):
"""
The base class for Activation+Quant fusions.
Should not be used directly.
"""
def __init__(
self,
quant_key: QuantKey,
) -> None:
self.quant_key = quant_key
self.quant_dtype = quant_key.dtype
assert self.quant_key in QUANT_OPS, (
f"unsupported quantization scheme {self.quant_key}"
)
self.QUANT_OP = QUANT_OPS[self.quant_key]
assert self.quant_key in FUSED_OPS, (
f"unsupported fusion scheme {self.quant_key}"
)
self.FUSED_OP = FUSED_OPS[self.quant_key]
self.silu_and_mul_matcher = MatcherSiluAndMul()
def empty_quant(self, *args: Any, **kwargs: Any) -> torch.Tensor:
kwargs = {"dtype": self.quant_dtype, "device": "cuda", **kwargs}
return torch.empty(*args, **kwargs)
@abstractmethod
def register(self, pm_pass: PatternMatcherPass) -> None:
raise NotImplementedError
class SiluMulFp8StaticQuantPattern(ActivationQuantPattern):
"""
Fusion for SiluMul+Fp8StaticQuant Pattern
"""
def __init__(self) -> None:
super().__init__(kFp8StaticTensorSym)
self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
def get_inputs(self) -> list[torch.Tensor]:
scale = self.quant_matcher.inputs()[1]
return [
*self.silu_and_mul_matcher.inputs(), # input
scale,
]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor,
scale: torch.Tensor,
) -> torch.Tensor:
result_silu_mul = self.silu_and_mul_matcher(input)
result_quant = self.quant_matcher(result_silu_mul, scale)
return result_quant[0]
def replacement(
input: torch.Tensor,
scale: torch.Tensor,
) -> torch.Tensor:
d = input.shape[-1] // 2
output_shape = input.shape[:-1] + (d,)
result = torch.empty(
output_shape, device=input.device, dtype=self.quant_dtype
)
at = auto_functionalized(
self.FUSED_OP, result=result, input=input, scale=scale
)
return at[1]
inps = self.get_inputs()
pattern(*inps)
register_replacement(pattern, replacement, inps, fwd_only, pm_pass)
class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
"""
Fusion for SiluMul+Nvfp4Quant Pattern
"""
def __init__(self) -> None:
super().__init__(kNvfp4Dynamic)
def get_inputs(self) -> list[torch.Tensor]:
result = self.empty_quant(5, 32)
output_scale = empty_i32(128, 4)
input_ = empty_bf16(5, 64)
scale = empty_fp32(1, 1)
return [result, output_scale, input_, scale]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
result: torch.Tensor,
output_scale: torch.Tensor,
input: torch.Tensor,
scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
result_silu_mul = self.silu_and_mul_matcher(input)
at = auto_functionalized(
self.QUANT_OP,
output=result,
input=result_silu_mul,
output_scale=output_scale,
input_scale=scale,
is_sf_swizzled_layout=True,
)
return at[1], at[2]
def replacement(
result: torch.Tensor,
output_scale: torch.Tensor,
input: torch.Tensor,
scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
at = auto_functionalized(
self.FUSED_OP,
result=result,
result_block_scale=output_scale,
input=input,
input_global_scale=scale,
)
return at[1], at[2]
register_replacement(pattern, replacement, self.get_inputs(), fwd_only, pm_pass)
class ActivationQuantFusionPass(VllmPatternMatcherPass):
"""
This pass fuses a pre-defined set of custom ops into fused ops.
It uses the torch pattern matcher to find the patterns and replace them.
Because patterns can only be registered once, the pass is a singleton.
This will be addressed in a future version of PyTorch:
https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
"""
@enable_fake_mode
def __init__(self, config: VllmConfig) -> None:
super().__init__(config)
self.patterns: PatternMatcherPass = PatternMatcherPass(
pass_name="activation_quant_fusion_pass"
)
pattern_silu_mul_fp8 = SiluMulFp8StaticQuantPattern()
pattern_silu_mul_fp8.register(self.patterns)
if silu_and_mul_nvfp4_quant_supported:
pattern_silu_mul_nvfp4 = SiluMulNvfp4QuantPattern()
pattern_silu_mul_nvfp4.register(self.patterns)
self.dump_patterns(config, self.patterns)
@VllmInductorPass.time_and_log
def __call__(self, graph: torch.fx.Graph) -> None:
self.matched_count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", self.matched_count)
def uuid(self) -> str:
return VllmInductorPass.hash_source(
self,
ActivationQuantPattern,
SiluMulFp8StaticQuantPattern,
SiluMulNvfp4QuantPattern,
)

View File

@@ -0,0 +1,862 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
from importlib.util import find_spec
from types import ModuleType
import torch
import torch._inductor.pattern_matcher as pm
import torch.fx as fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import PatternMatcherPass
from vllm.config import VllmConfig
from vllm.config.utils import Range
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8StaticTensorSym,
)
from vllm.platforms import current_platform
from vllm.utils.torch_utils import (
direct_register_custom_op,
)
from ..inductor_pass import enable_fake_mode
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm
FP8_DTYPE = current_platform.fp8_dtype()
logger = init_logger(__name__)
flashinfer_comm: ModuleType | None = None
if find_spec("flashinfer"):
try:
import flashinfer.comm as _flashinfer_comm
if hasattr(_flashinfer_comm, "allreduce_fusion") and hasattr(
_flashinfer_comm, "create_allreduce_fusion_workspace"
):
flashinfer_comm = _flashinfer_comm
except ImportError:
pass
if hasattr(torch.ops._C, "scaled_fp4_quant"):
STATIC_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant.default
# Max size of the input tensor per world size per device capability
# to use flashinfer fused allreduce
FI_ALLREDUCE_FUSION_MAX_SIZE_MB: dict[int, dict[int, float]] = {
90: {
2: 64, # 64MB
4: 2, # 2MB
8: 0.5, # 0.5MB
},
100: {
2: 64, # 64MB
4: 32, # 32MB
8: 1, # 1MB
},
}
# Max size of the input tensor per world size per device capability
# to use flashinfer one shot fused allreduce
# OneShot max size is at most 64MB / world size (FlashInfer restriction)
_FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB: dict[int, dict[int, float]] = {
90: {
2: 32, # 32MB
4: 2, # 2MB
8: 0.5, # 0.5MB
},
100: {
2: 32, # 32MB
4: 4, # 4MB
8: 1, # 1MB
},
}
if flashinfer_comm is not None:
from vllm.distributed.device_communicators.flashinfer_all_reduce import (
destroy_fi_ar_workspace,
get_fi_ar_quant_workspace,
get_fi_ar_workspace,
initialize_fi_ar_quant_workspace,
initialize_fi_ar_workspace,
)
ar_fusion_patterns = flashinfer_comm.AllReduceFusionPattern
MiB = 1024 * 1024
def call_trtllm_fused_allreduce_norm(
allreduce_in: torch.Tensor,
residual: torch.Tensor,
rms_gamma: torch.Tensor,
rms_eps: float,
world_size: int,
launch_with_pdl: bool,
fp32_acc: bool,
max_token_num: int,
pattern_code: int,
norm_out: torch.Tensor | None = None,
quant_out: torch.Tensor | None = None,
scale_out: torch.Tensor | None = None,
scale_factor: torch.Tensor | None = None,
) -> None:
num_tokens, hidden_size = allreduce_in.shape
element_size = allreduce_in.element_size()
current_tensor_size = num_tokens * hidden_size * element_size
max_tensor_size = max_token_num * hidden_size * element_size
assert current_tensor_size <= max_tensor_size, (
f"Current tensor size {current_tensor_size} is larger than "
f"max token num {max_token_num} * hidden size {hidden_size} * "
f"element size {element_size}"
)
curr_device = current_platform.get_device_capability()
device_capability = curr_device.to_int() if curr_device is not None else None
# Get one shot input size limit for the current world size
# for the current device capability
max_one_shot_size = _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB.get(
device_capability, # type: ignore[arg-type, unused-ignore]
{},
).get(world_size, None)
# Use one shot if no max size is specified
use_oneshot = (
max_one_shot_size is None or current_tensor_size <= max_one_shot_size * MiB
)
# Select workspace based on pattern: quant patterns use the
# trtllm quant workspace, non-quant patterns use the primary workspace.
if pattern_code in (
ar_fusion_patterns.kARResidualRMSNormFP8Quant,
ar_fusion_patterns.kARResidualRMSNormFP4Quant,
):
workspace = get_fi_ar_quant_workspace()
else:
workspace = get_fi_ar_workspace()
assert workspace is not None, (
"Flashinfer workspace must be initialized when using flashinfer"
)
assert flashinfer_comm is not None
if norm_out is None:
norm_out = allreduce_in
residual_out = residual
else:
# return residual_out as allreduce_out with zeroed residual_in
# as flashinfer does not support rms_norm
# and allreduce_out together
residual_out = allreduce_in
layout_code = None
# layout_code only supported by trtllm backend
if workspace.backend == "trtllm":
# in vllm we only support swizzled layout
layout_code = flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4
flashinfer_comm.allreduce_fusion(
input=allreduce_in,
workspace=workspace,
pattern=pattern_code,
launch_with_pdl=launch_with_pdl,
output=None,
residual_out=residual_out,
norm_out=norm_out,
quant_out=quant_out,
scale_out=scale_out,
residual_in=residual,
rms_gamma=rms_gamma,
rms_eps=rms_eps,
scale_factor=scale_factor,
layout_code=layout_code,
use_oneshot=use_oneshot,
fp32_acc=fp32_acc,
)
def call_trtllm_fused_allreduce_norm_fake(
allreduce_in: torch.Tensor,
residual: torch.Tensor,
rms_gamma: torch.Tensor,
rms_eps: float,
world_size: int,
launch_with_pdl: bool,
fp32_acc: bool,
max_token_num: int,
pattern_code: int,
norm_out: torch.Tensor | None = None,
quant_out: torch.Tensor | None = None,
scale_out: torch.Tensor | None = None,
scale_factor: torch.Tensor | None = None,
) -> None:
pass
direct_register_custom_op(
op_name="flashinfer_trtllm_fused_allreduce_norm",
op_func=call_trtllm_fused_allreduce_norm,
mutates_args=[
"allreduce_in",
"residual",
"norm_out",
"quant_out",
"scale_out",
],
fake_impl=call_trtllm_fused_allreduce_norm_fake,
)
flashinfer_trtllm_fused_allreduce_norm = (
torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default
)
class FlashInferFusedAllReduceParams:
"""Parameters for FlashInfer fused allreduce operations."""
def __init__(
self,
world_size: int,
max_token_num: int = 1024,
) -> None:
self.world_size = world_size
self.launch_with_pdl = True
self.fp32_acc = True
self.max_token_num = max_token_num
def get_trtllm_fused_allreduce_kwargs(self) -> dict[str, bool | int]:
return {
"world_size": self.world_size,
"launch_with_pdl": self.launch_with_pdl,
"fp32_acc": self.fp32_acc,
"max_token_num": self.max_token_num,
}
# TODO(luka): unify
class BasePattern:
def __init__(self, dtype: torch.dtype, device: str | None) -> None:
self.dtype = dtype
self.device = device
self.tp = get_tp_group()
self.tp_size = get_tensor_model_parallel_world_size()
class AllReduceRMSNormPattern(BasePattern):
"""
This pattern replaces the allreduce + rms norm (without residual)
with fused flashinfer implementation.
Applies to allreduce + rmsnorm before attn in the first Transformer block.
"""
def __init__(
self,
epsilon: float,
dtype: torch.dtype,
device: str | None,
allreduce_params: FlashInferFusedAllReduceParams,
) -> None:
super().__init__(dtype, device)
self.epsilon = epsilon
self.allreduce_params = allreduce_params
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
def get_inputs(self) -> list[torch.Tensor]:
input, weight = self.rmsnorm_matcher.inputs()
# input goes through allreduce first, always 16-bit
return [input.to(self.dtype), weight]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor, weight: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
allreduce_output = tensor_model_parallel_all_reduce(input)
rms = self.rmsnorm_matcher(allreduce_output, weight)
return rms, allreduce_output
def replacement(
input: torch.Tensor, weight: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
residual = torch.zeros_like(input)
rms_result = torch.empty_like(input)
assert flashinfer_comm is not None, "FlashInfer must be enabled"
allreduce = auto_functionalized(
flashinfer_trtllm_fused_allreduce_norm,
allreduce_in=input,
residual=residual,
norm_out=rms_result,
quant_out=None,
scale_out=None,
rms_gamma=weight,
rms_eps=self.epsilon,
pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm,
**self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
)
# rms_result, allreduce_in
return allreduce[3], allreduce[1]
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
)
class AllReduceFusedAddRMSNormPattern(BasePattern):
"""
This pattern replaces the allreduce + rms norm (with residual)
with fused flashinfer implementation.
Applies to o_proj + rmsnorm after attn and mlp + rmsnorm before attn.
"""
def __init__(
self,
epsilon: float,
dtype: torch.dtype,
device: str | None,
allreduce_params: FlashInferFusedAllReduceParams,
) -> None:
super().__init__(dtype, device)
self.epsilon = epsilon
self.allreduce_params = allreduce_params
self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
def get_inputs(self) -> list[torch.Tensor]:
input, residual, weight = self.rmsnorm_matcher.inputs()
# input goes through allreduce first, always 16-bit
return [residual, input.to(self.dtype), weight]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
allreduce_output = tensor_model_parallel_all_reduce(input)
rms, residual = self.rmsnorm_matcher(allreduce_output, weight, residual)
return rms, residual
def replacement(
residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
assert flashinfer_comm is not None, "FlashInfer must be enabled"
allreduce = auto_functionalized(
flashinfer_trtllm_fused_allreduce_norm,
allreduce_in=input,
residual=residual,
norm_out=None,
quant_out=None,
scale_out=None,
rms_gamma=weight,
rms_eps=self.epsilon,
pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm,
**self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
)
# allreduce_in, residual
return allreduce[1], allreduce[2]
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
)
# Same pattern, but only return the output and not residual
# (helpful for end of graph where residual is not used again)
first_return_only = lambda fn: lambda a, b, c: fn(a, b, c)[0]
pm.register_replacement(
first_return_only(pattern), # type: ignore[no-untyped-call]
first_return_only(replacement), # type: ignore[no-untyped-call]
self.get_inputs(),
pm.fwd_only,
pm_pass,
)
class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern):
"""
This pattern replaces the allreduce + rms norm (without residual)
+ static fp8 quant with fused flashinfer implementation.
Applies to allreduce + rmsnorm + quant before attn
in the first Transformer block.
"""
def __init__(
self,
epsilon: float,
dtype: torch.dtype,
device: str | None,
allreduce_params: FlashInferFusedAllReduceParams,
) -> None:
super().__init__(dtype, device)
self.epsilon = epsilon
self.allreduce_params = allreduce_params
self.quant_dtype = torch.float8_e4m3fn
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
def get_inputs(self) -> list[torch.Tensor]:
input, weight = self.rmsnorm_matcher.inputs()
_, scale = self.quant_matcher.inputs()
# input goes through allreduce first, always 16-bit
return [input.to(self.dtype), weight, scale]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
all_reduce = tensor_model_parallel_all_reduce(input)
rms = self.rmsnorm_matcher(all_reduce, weight)
quant, _ = self.quant_matcher(rms, scale)
return quant, all_reduce
def replacement(
input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
residual = torch.zeros_like(input)
result_rms = torch.empty_like(input)
result_quant = torch.empty_like(input, dtype=self.quant_dtype)
assert flashinfer_comm is not None, "FlashInfer must be enabled"
allreduce = auto_functionalized(
flashinfer_trtllm_fused_allreduce_norm,
allreduce_in=input,
residual=residual,
norm_out=result_rms,
quant_out=result_quant,
scale_out=None,
rms_gamma=weight,
rms_eps=self.epsilon,
# We don't use norm_out afterwards
pattern_code=(
flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP8Quant
),
scale_factor=scale,
**self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
)
# quant_out, allreduce_output
return allreduce[4], allreduce[1]
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
)
class AllReduceFusedAddRMSNormStaticQuantFP8Pattern(BasePattern):
"""
This pattern replaces the allreduce + rms norm (with residual)
+ static fp8 quant with fused flashinfer implementation.
Applies to o_proj + rmsnorm after attn + quant and
mlp + rmsnorm + quant before attn.
"""
def __init__(
self,
epsilon: float,
dtype: torch.dtype,
device: str | None,
allreduce_params: FlashInferFusedAllReduceParams,
) -> None:
super().__init__(dtype, device)
self.epsilon = epsilon
self.allreduce_params = allreduce_params
self.quant_dtype = torch.float8_e4m3fn
self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
def get_inputs(self) -> list[torch.Tensor]:
input, residual, weight = self.rmsnorm_matcher.inputs()
_, scale = self.quant_matcher.inputs()
# input goes through allreduce first, always 16-bit
return [residual, input.to(self.dtype), weight, scale]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
residual: torch.Tensor,
input: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
allreduce_output = tensor_model_parallel_all_reduce(input)
rms, res = self.rmsnorm_matcher(allreduce_output, weight, residual)
quant, _ = self.quant_matcher(rms, scale)
return quant, res
def replacement(
residual: torch.Tensor,
input: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
result_quant = torch.empty_like(input, dtype=self.quant_dtype)
assert flashinfer_comm is not None, "FlashInfer must be enabled"
allreduce = auto_functionalized(
flashinfer_trtllm_fused_allreduce_norm,
allreduce_in=input,
residual=residual,
norm_out=None,
quant_out=result_quant,
scale_out=None,
rms_gamma=weight,
rms_eps=self.epsilon,
# We don't use norm_out afterwards
pattern_code=(
flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP8Quant
),
scale_factor=scale,
**self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
)
# quant_out, rms_norm_residual
return allreduce[4], allreduce[2]
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
)
class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
"""
This pattern replaces the allreduce + rms norm (without residual)
+ static nvfp4 quant with fused flashinfer implementation.
Applies to allreduce + rmsnorm + quant before attn
in the first Transformer block.
"""
def __init__(
self,
epsilon: float,
dtype: torch.dtype,
device: str | None,
allreduce_params: FlashInferFusedAllReduceParams,
) -> None:
super().__init__(dtype, device)
self.epsilon = epsilon
self.allreduce_params = allreduce_params
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
def get_inputs(self) -> list[torch.Tensor]:
input = torch.empty([1, 16, 16], device=self.device, dtype=self.dtype)
quant_result = torch.empty((16, 8), device=self.device, dtype=torch.uint8)
input_global_scale = torch.empty(
[1, 1], device=self.device, dtype=torch.float32
)
weight = torch.empty([16], device=self.device, dtype=self.dtype)
output_scale = torch.empty([128, 4], device=self.device, dtype=torch.int32)
return [input, quant_result, weight, input_global_scale, output_scale]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor,
quant_result: torch.Tensor,
weight: torch.Tensor,
input_global_scale: torch.Tensor,
output_scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
all_reduce = tensor_model_parallel_all_reduce(input)
rms = self.rmsnorm_matcher(all_reduce, weight)
quant_out_tuple = auto_functionalized(
STATIC_FP4_QUANT_OP,
output=quant_result,
input=rms,
output_scale=output_scale,
input_scale=input_global_scale,
is_sf_swizzled_layout=True,
)
# quant_out, allreduce_output, output_scale
return quant_out_tuple[1], all_reduce, quant_out_tuple[2]
def replacement(
input: torch.Tensor,
quant_result: torch.Tensor,
weight: torch.Tensor,
input_global_scale: torch.Tensor,
output_scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
residual = torch.zeros_like(input)
result_rms = torch.empty_like(input)
assert flashinfer_comm is not None, "FlashInfer must be enabled"
allreduce = auto_functionalized(
flashinfer_trtllm_fused_allreduce_norm,
allreduce_in=input,
residual=residual,
norm_out=result_rms,
quant_out=quant_result,
scale_out=output_scale,
rms_gamma=weight,
rms_eps=self.epsilon,
# We don't use norm_out afterwards
pattern_code=(
flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP4Quant
),
scale_factor=input_global_scale,
**self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
)
# quant_out, allreduce_output, output_scale
return allreduce[4], allreduce[1], allreduce[5]
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
)
class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
"""
This pattern replaces the allreduce + rms norm (with residual)
+ static nvfp4 quant with fused flashinfer implementation.
Applies to o_proj + rmsnorm after attn + quant and
mlp + rmsnorm + quant before attn.
"""
def __init__(
self,
epsilon: float,
dtype: torch.dtype,
device: str | None,
allreduce_params: FlashInferFusedAllReduceParams,
) -> None:
super().__init__(dtype, device)
self.epsilon = epsilon
self.allreduce_params = allreduce_params
self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
def get_inputs(self) -> list[torch.Tensor]:
input = torch.empty([16, 16], device=self.device, dtype=self.dtype)
residual = torch.empty([16, 16], device=self.device, dtype=self.dtype)
weight = torch.empty([16, 16], device=self.device, dtype=self.dtype)
quant_result = torch.empty((16, 8), device=self.device, dtype=torch.uint8)
input_global_scale = torch.empty(
[1, 1], device=self.device, dtype=torch.float32
)
output_scale = torch.empty([128, 4], device=self.device, dtype=torch.int32)
return [
quant_result,
residual,
input,
output_scale,
weight,
input_global_scale,
]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
quant_result: torch.Tensor,
residual: torch.Tensor,
input: torch.Tensor,
output_scale: torch.Tensor,
weight: torch.Tensor,
input_global_scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
allreduce_output = tensor_model_parallel_all_reduce(input)
rms, residual = self.rmsnorm_matcher(allreduce_output, weight, residual)
quant_out_tuple = auto_functionalized(
STATIC_FP4_QUANT_OP,
output=quant_result,
input=rms,
output_scale=output_scale,
input_scale=input_global_scale,
is_sf_swizzled_layout=True,
)
# quant_out, allreduce_output, output_scale
return quant_out_tuple[1], residual, quant_out_tuple[2]
def replacement(
quant_result: torch.Tensor,
residual: torch.Tensor,
input: torch.Tensor,
output_scale: torch.Tensor,
weight: torch.Tensor,
input_global_scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
assert flashinfer_comm is not None, "FlashInfer must be enabled"
allreduce = auto_functionalized(
flashinfer_trtllm_fused_allreduce_norm,
allreduce_in=input,
residual=residual,
norm_out=None,
quant_out=quant_result,
scale_out=output_scale,
rms_gamma=weight,
rms_eps=self.epsilon,
# We don't use norm_out afterwards
pattern_code=(
flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP4Quant
),
scale_factor=input_global_scale,
**self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
)
# quant_out, rms_norm_residual, output_scale
return allreduce[4], allreduce[2], allreduce[5]
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
)
class AllReduceFusionPass(VllmPatternMatcherPass):
def __init__(self, config: VllmConfig) -> None:
super().__init__(config)
self.disabled = True
self.tp_size = get_tensor_model_parallel_world_size()
if self.tp_size <= 1:
logger.warning_once("AllReduce fusion pass is disabled for tp_size <= 1.")
return
self.patterns: PatternMatcherPass = PatternMatcherPass(
pass_name="all_reduce_fusion_pass"
)
if config.model_config is None:
logger.warning_once(
"AllReduce fusion pass is disabled for missing model_config."
)
return
self.hidden_dim = config.model_config.get_hidden_size()
self.group = get_tp_group().device_group
rank = get_tensor_model_parallel_rank()
if flashinfer_comm is None:
logger.warning(
"Flashinfer is not installed or comm module not found, "
"skipping allreduce fusion pass"
)
return
max_size = config.compilation_config.pass_config.flashinfer_max_size(
self.tp_size
)
if max_size is None:
# Flashinfer doesn't support current world size
logger.warning(
"Flashinfer allreduce fusion is not supported for world size %s"
" or max size is not provided",
self.tp_size,
)
return
element_size = torch.tensor([], dtype=self.model_dtype).element_size()
self.max_token_num = max_size // (self.hidden_dim * element_size)
# take the min to save workspace size and we'll never use more
# than max_num_batched_tokens anyways
self.max_token_num = min(
self.max_token_num, config.scheduler_config.max_num_batched_tokens
)
logger.debug_once(
f"Flashinfer max size: {max_size // (1024 * 1024)} MB,"
"Maximal number of tokens used by "
f"Flashinfer Allreduce Fusion: {self.max_token_num}",
scope="global",
)
for workspace_init_fn in [
initialize_fi_ar_workspace,
initialize_fi_ar_quant_workspace,
]:
try:
workspace_init_fn(
world_size=self.tp_size,
rank=rank,
max_token_num=self.max_token_num,
hidden_dim=self.hidden_dim,
dtype=self.model_dtype,
group=self.group,
)
except Exception as e:
if "multicast" in str(e).lower():
logger.warning(
"AllReduce fusion pass is disabled: flashinfer workspace "
"creation failed: %s. This is expected on GPUs without "
"NVSwitch (e.g., NVLink bridge-only or PCIe topologies). "
"Falling back to non-fused allreduce.",
str(e),
)
else:
logger.warning(
"Failed to initialize FlashInfer All Reduce workspace: %s. "
"AllReduce fusion pass will be disabled.",
e,
)
return
self.allreduce_params = FlashInferFusedAllReduceParams(
world_size=self.tp_size,
max_token_num=self.max_token_num,
)
self.register_patterns()
self.dump_patterns(config, self.patterns)
@enable_fake_mode
def register_patterns(self) -> None:
supports_quantization = get_fi_ar_quant_workspace() is not None
for epsilon in [1e-5, 1e-6]:
if supports_quantization:
AllReduceFusedRMSNormStaticQuantFP8Pattern(
epsilon,
self.model_dtype,
self.device,
self.allreduce_params,
).register(self.patterns)
AllReduceFusedAddRMSNormStaticQuantFP8Pattern(
epsilon,
self.model_dtype,
self.device,
self.allreduce_params,
).register(self.patterns)
if current_platform.has_device_capability(100):
AllReduceFusedRMSNormStaticQuantNVFP4Pattern(
epsilon,
self.model_dtype,
self.device,
self.allreduce_params,
).register(self.patterns)
AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(
epsilon,
self.model_dtype,
self.device,
self.allreduce_params,
).register(self.patterns)
AllReduceRMSNormPattern(
epsilon,
self.model_dtype,
self.device,
self.allreduce_params,
).register(self.patterns)
AllReduceFusedAddRMSNormPattern(
epsilon,
self.model_dtype,
self.device,
self.allreduce_params,
).register(self.patterns)
# WARNING: This is a hack to clear the pattern matcher cache
# and allow multiple values of epsilon.
torch._inductor.pattern_matcher._seen_patterns.clear()
self.disabled = False
def is_applicable_for_range(self, compile_range: Range) -> bool:
if self.disabled:
logger.warning_once("AllReduce fusion pass is disabled.")
return False
return bool(compile_range.end <= self.max_token_num)
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph) -> None:
if self.disabled:
logger.debug("AllReduceFusionPass disabled")
return
self.matched_count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", self.matched_count)
def __del__(self) -> None:
if getattr(self, "disabled", True):
return
with contextlib.suppress(Exception):
destroy_fi_ar_workspace()

View File

@@ -0,0 +1,374 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import Any, ParamSpec
import torch
import torch._inductor.pattern_matcher as pm
from torch import fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import PatternMatcherPass
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kNvfp4Dynamic,
kStaticTensorScale,
)
from vllm.platforms import current_platform
from vllm.utils.math_utils import round_up
from ..fx_utils import is_func
from ..inductor_pass import enable_fake_mode
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
from .matcher_utils import MatcherQuantFP8
from .rms_quant_fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
logger = init_logger(__name__)
P = ParamSpec("P")
FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8
ATTN_OP = torch.ops.vllm.unified_attention_with_output.default
RESHAPE_OP = torch.ops.aten.reshape.default
class AttentionQuantPattern(ABC):
"""
The base class for Attn+Quant fusions.
Should not be used directly.
"""
def __init__(
self,
layer: Attention,
quant_key: QuantKey,
dtype: torch.dtype,
) -> None:
self.layer = layer
self.layer_name = layer.layer_name
self.num_heads = layer.num_heads
self.head_size = layer.head_size
self.quant_key = quant_key
self.quant_dtype = quant_key.dtype
self.dtype = dtype
assert self.quant_key in QUANT_OPS, (
f"unsupported quantization scheme {self.quant_key}"
)
self.QUANT_OP = QUANT_OPS[self.quant_key]
def empty(self, *args: Any, **kwargs: Any) -> torch.Tensor:
kwargs = {"dtype": self.dtype, "device": "cuda", **kwargs}
return torch.empty(*args, **kwargs)
def empty_quant(self, *args: Any, **kwargs: Any) -> torch.Tensor:
kwargs = {"dtype": self.quant_dtype, "device": "cuda", **kwargs}
return torch.empty(*args, **kwargs)
@staticmethod
def wrap_trace_fn(
trace_fn: Callable[P, fx.GraphModule],
*process_fx_fns: Callable[[fx.GraphModule], None],
) -> Callable[P, fx.GraphModule]:
def wrapped(*args: P.args, **kwargs: P.kwargs) -> fx.GraphModule:
gm = trace_fn(*args, **kwargs)
for process_fx in process_fx_fns:
process_fx(gm)
return gm
return wrapped
@staticmethod
def fx_view_to_reshape(gm: torch.fx.GraphModule) -> None:
from torch._inductor.fx_passes.post_grad import view_to_reshape
view_to_reshape(gm)
@staticmethod
def remove_noop_permutes(gm: torch.fx.GraphModule) -> None:
for node in gm.graph.nodes:
if not is_func(node, torch.ops.aten.permute.default):
continue
dims = node.args[1]
if any(dim != i for i, dim in enumerate(dims)):
continue
# this is now an identity op, remove
node.replace_all_uses_with(node.args[0])
gm.graph.erase_node(node)
def register_if_supported(self, pm_pass: PatternMatcherPass) -> None:
if self.layer.impl.fused_output_quant_supported(self.quant_key):
self._register(pm_pass)
@abstractmethod
def _register(self, pm_pass: PatternMatcherPass) -> None:
raise NotImplementedError
class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
"""
Fusion for Attention+Fp8StaticQuant.
Only triggers when the attention implementation returns True in
`fused_output_quant_supported()`. If the pattern is found, the
Fp8StaticQuant op will be removed from the graph, and its scale
will be passed into Attention op as the `output_scale` argument.
"""
def __init__(
self,
layer: Attention,
dtype: torch.dtype,
symmetric: bool = True,
) -> None:
quant_key = QuantKey(
dtype=FP8_DTYPE, scale=kStaticTensorScale, symmetric=symmetric
)
super().__init__(layer, quant_key, dtype)
self.quant_matcher = MatcherQuantFP8(quant_key)
def _register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
output_attn: torch.Tensor,
scale: torch.Tensor,
kv_cache_dummy_dep: torch.Tensor,
) -> torch.Tensor:
at1 = auto_functionalized(
ATTN_OP,
query=q,
key=k,
value=v,
output=output_attn,
layer_name=self.layer_name,
output_scale=None,
output_block_scale=None,
kv_cache_dummy_dep=kv_cache_dummy_dep,
)
attn_out_view = RESHAPE_OP(
at1[1], [q.shape[0], self.num_heads * self.head_size]
)
return self.quant_matcher(attn_out_view, scale)[0]
def replacement(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
output_attn: torch.Tensor,
scale: torch.Tensor,
kv_cache_dummy_dep: torch.Tensor,
) -> torch.Tensor:
# attn output in quant_dtype
output_attn = torch.ops.aten.full.default(
[q.shape[0], self.num_heads, self.head_size],
0.0,
dtype=self.quant_dtype,
device=q.device,
)
at1 = auto_functionalized(
ATTN_OP,
query=q,
key=k,
value=v,
output=output_attn,
layer_name=self.layer_name,
output_scale=scale,
output_block_scale=None,
kv_cache_dummy_dep=kv_cache_dummy_dep,
)
return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size])
inputs = [
self.empty(5, self.num_heads, self.head_size), # q
self.empty(5, self.num_heads, self.head_size), # k
self.empty(5, self.num_heads, self.head_size), # v
self.empty(5, self.num_heads, self.head_size), # attn_output
empty_fp32(1, 1), # scale
self.empty(0), # kv_cache_dummy_dep
]
pm.register_replacement(
pattern,
replacement,
inputs,
AttentionQuantPattern.wrap_trace_fn(
pm.fwd_only,
AttentionQuantPattern.fx_view_to_reshape,
AttentionQuantPattern.remove_noop_permutes,
),
pm_pass,
)
class AttentionNvfp4QuantPattern(AttentionQuantPattern):
"""
Fusion for Attention+Nvfp4Quant.
Only triggers when the attention implementation returns True in
`fused_output_quant_supported()`. If the pattern is found, the
Nvfp4Quant op will be removed from the graph, and its scale
will be passed into Attention op as the `output_scale` argument.
"""
def __init__(self, layer: Attention, dtype: torch.dtype) -> None:
super().__init__(layer, kNvfp4Dynamic, dtype)
def _register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
output_attn: torch.Tensor,
output_quant: torch.Tensor,
output_scale: torch.Tensor,
input_scale: torch.Tensor,
kv_cache_dummy_dep: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
at1 = auto_functionalized(
ATTN_OP,
query=q,
key=k,
value=v,
output=output_attn,
layer_name=self.layer_name,
output_scale=None,
output_block_scale=None,
kv_cache_dummy_dep=kv_cache_dummy_dep,
)
attn_out_view = RESHAPE_OP(
at1[1], [q.shape[0], self.num_heads * self.head_size]
)
at2 = auto_functionalized(
self.QUANT_OP,
output=output_quant,
input=attn_out_view,
output_scale=output_scale,
input_scale=input_scale,
is_sf_swizzled_layout=True,
)
output_scale_view = torch.ops.aten.view.dtype(at2[2], FP8_DTYPE)
return at2[1], output_scale_view
def replacement(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
output_attn: torch.Tensor,
output_quant: torch.Tensor,
output_scale: torch.Tensor,
input_scale: torch.Tensor,
kv_cache_dummy_dep: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
# attention output in quant_dtype
output_attn = torch.ops.aten.full.default(
[q.shape[0], self.num_heads, self.head_size // 2],
0.0,
dtype=self.quant_dtype,
device=q.device,
)
# attention output block scale
output_scale_view = torch.ops.aten.view.dtype(output_scale, FP8_DTYPE)
at2 = auto_functionalized(
ATTN_OP,
query=q,
key=k,
value=v,
output=output_attn,
layer_name=self.layer_name,
output_scale=input_scale,
output_block_scale=output_scale_view,
kv_cache_dummy_dep=kv_cache_dummy_dep,
)
output = RESHAPE_OP(at2[1], [-1, self.num_heads * self.head_size // 2])
return output, at2[2]
inputs = [
empty_bf16(5, self.num_heads, self.head_size), # q
empty_bf16(5, self.num_heads, self.head_size), # k
empty_bf16(5, self.num_heads, self.head_size), # v
empty_bf16(5, self.num_heads, self.head_size), # output_attn
self.empty_quant(5, self.num_heads * self.head_size // 2), # output_quant
empty_i32(
128, round_up(self.num_heads * self.head_size // 16, 4)
), # output_scale
empty_fp32(1, 1), # input_scale
self.empty(0), # kv_cache_dummy_dep
]
pm.register_replacement(
pattern,
replacement,
inputs,
AttentionQuantPattern.wrap_trace_fn(
pm.fwd_only,
AttentionQuantPattern.fx_view_to_reshape,
AttentionQuantPattern.remove_noop_permutes,
),
pm_pass,
)
class AttnFusionPass(VllmPatternMatcherPass):
"""
This pass fuses post-attention quantization onto attention if supported.
It uses the pattern matcher and matches each layer manually, as strings
cannot be wildcarded. This also lets us check support on attention layers
upon registration instead of during pattern matching.
Currently, only static fp8 quant is supported, but patterns could easily be
added for other quant schemes and dtypes. The bigger hurdle for wider
support are attention kernels, which need to support fusing output quant.
"""
@enable_fake_mode
def __init__(self, config: VllmConfig) -> None:
super().__init__(config)
self.patterns = PatternMatcherPass(pass_name="attn_fusion_pass")
attn_layers = get_layers_from_vllm_config(config, Attention)
for layer_name, layer in attn_layers.items():
pattern_fp8 = AttentionFp8StaticQuantPattern(
layer, config.model_config.dtype
)
pattern_fp8.register_if_supported(self.patterns)
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
pattern_nvfp4 = AttentionNvfp4QuantPattern(
layer, config.model_config.dtype
)
pattern_nvfp4.register_if_supported(self.patterns)
if len(attn_layers) == 0:
logger.warning(
"Attention + quant fusion is enabled, but no attention layers "
"were found in CompilationConfig.static_forward_context "
"so no fusion patterns were registered."
)
self.dump_patterns(config, self.patterns)
@VllmInductorPass.time_and_log
def __call__(self, graph: torch.fx.graph.Graph) -> None:
self.matched_count = self.patterns.apply(graph)
logger.debug("Fused quant onto %s attention nodes", self.matched_count)
def uuid(self) -> str:
return VllmInductorPass.hash_source(
self,
AttentionQuantPattern,
AttentionFp8StaticQuantPattern,
AttentionNvfp4QuantPattern,
)

View File

@@ -0,0 +1,423 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import torch._inductor.pattern_matcher as pm
import torch.fx as fx
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch.distributed._symmetric_memory import enable_symm_mem_for_group
from vllm.config import VllmConfig
from vllm.config.utils import Range
from vllm.distributed import get_tp_group
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_world_size,
)
from vllm.logger import init_logger
from vllm.platforms import current_platform
from ..inductor_pass import enable_fake_mode
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
FP8_DTYPE = current_platform.fp8_dtype()
logger = init_logger(__name__)
class BasePattern:
def __init__(self, dtype: torch.dtype, device: str | None) -> None:
self.dtype = dtype
self.device = device
self.tp = get_tp_group()
self.tp_size = get_tensor_model_parallel_world_size()
class GEMMReduceScatterPattern(BasePattern):
def get_inputs(self) -> list[torch.Tensor]:
mul = torch.empty([16, 4], device=self.device, dtype=self.dtype)
mm_weight = torch.empty([4, 4], device=self.device, dtype=self.dtype)
return [mul, mm_weight]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(mul: torch.Tensor, mm_weight: torch.Tensor) -> torch.Tensor:
mm = torch.ops.aten.mm.default(mul, mm_weight)
reduce_scatter = torch.ops.vllm.reduce_scatter.default(
mm,
dim=0,
world_size=self.tp_size,
group_name=self.tp.unique_name,
)
return reduce_scatter
def replacement(mul: torch.Tensor, mm_weight: torch.Tensor) -> torch.Tensor:
gemm_rs = torch.ops.symm_mem.fused_matmul_reduce_scatter(
mul,
mm_weight,
"avg",
scatter_dim=0,
group_name=self.tp.device_group.group_name,
)
return gemm_rs
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
)
class AllGatherGEMMPattern(BasePattern):
def get_inputs(self) -> list[torch.Tensor]:
x = torch.empty([4, 4], device=self.device, dtype=self.dtype)
weight = torch.empty([4, 4], device=self.device, dtype=self.dtype)
return [x, weight]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
x: torch.Tensor,
weight: torch.Tensor,
) -> torch.Tensor:
all_gather = torch.ops.vllm.all_gather.default(
x,
dim=0,
world_size=self.tp_size,
group_name=self.tp.unique_name,
)
return torch.ops.aten.mm.default(all_gather, weight)
def replacement(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_matmul(
x,
[weight],
gather_dim=0,
group_name=self.tp.device_group.group_name,
)
return mm_outputs
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
)
class ScaledMMReduceScatterPattern(BasePattern):
def get_inputs(self) -> list[torch.Tensor]:
input = torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
mm_weight = (
torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
.contiguous()
.transpose(0, 1)
)
scale_a = torch.empty([16, 1], device=self.device, dtype=torch.float32)
scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)
return [input, mm_weight, scale_a, scale_b]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor,
mat2: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
) -> torch.Tensor:
scaled_mm = torch.ops.aten._scaled_mm.default(
input,
mat2=mat2,
scale_a=scale_a,
scale_b=scale_b,
bias=None,
scale_result=None,
out_dtype=self.dtype,
)
reduce_scatter = torch.ops.vllm.reduce_scatter.default(
scaled_mm,
dim=0,
world_size=self.tp_size,
group_name=self.tp.unique_name,
)
return reduce_scatter
def replacement(
input: torch.Tensor,
mat2: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
) -> torch.Tensor:
# Calculate output shape: input @ mat2 with scatter_dim reduced
output_shape = [*input.shape[:-1], mat2.shape[1]]
scatter_dim = 0
gemm_rs = torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter(
input,
mat2,
scale_a,
scale_b,
"avg",
scatter_dim, # orig_scatter_dim
scatter_dim, # scatter_dim_after_maybe_reshape
self.tp.device_group.group_name,
output_shape,
None, # bias
None, # result_scale
self.dtype, # out_dtype
False, # use_fast_accum
)
return gemm_rs
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
)
class AllGatherScaledMMPattern(BasePattern):
def get_inputs(self) -> list[torch.Tensor]:
x = torch.empty([8, 16], device=self.device, dtype=FP8_DTYPE)
weight = (
torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
.contiguous()
.transpose(0, 1)
)
s1 = x.shape[0] * self.tp_size
scale_a = torch.empty([s1, 1], device=self.device, dtype=torch.float32)
scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)
return [x, weight, scale_a, scale_b]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
x: torch.Tensor,
weight: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
) -> torch.Tensor:
all_gather = torch.ops.vllm.all_gather.default(
x, dim=0, world_size=self.tp_size, group_name=self.tp.unique_name
)
return torch.ops.aten._scaled_mm.default(
all_gather,
mat2=weight,
scale_a=scale_a,
scale_b=scale_b,
bias=None,
scale_result=None,
out_dtype=self.dtype,
)
def replacement(
x: torch.Tensor,
weight: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
) -> torch.Tensor:
ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_scaled_matmul( # noqa
x,
[weight],
scale_a,
[scale_b],
gather_dim=0,
biases=[None],
result_scales=[None],
out_dtypes=[self.dtype],
use_fast_accum=[False],
group_name=self.tp.device_group.group_name,
)
return mm_outputs
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
)
class CutlassScaledMMReduceScatterPattern(BasePattern):
def get_inputs(self) -> list[torch.Tensor]:
input = torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
mm_weight = (
torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
.contiguous()
.transpose(0, 1)
)
scale_a = torch.empty([16, 1], device=self.device, dtype=torch.float32)
scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)
cutlass_mm_output = torch.empty([16, 16], device=self.device, dtype=self.dtype)
return [input, mm_weight, scale_a, scale_b, cutlass_mm_output]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor,
weight: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
cutlass_mm_output: torch.Tensor,
) -> torch.Tensor:
cutlass_scaled_mm = torch.ops.higher_order.auto_functionalized(
torch.ops._C.cutlass_scaled_mm.default,
out=cutlass_mm_output,
a=input,
b=weight,
a_scales=scale_a,
b_scales=scale_b,
bias=None,
)
reduce_scatter = torch.ops.vllm.reduce_scatter.default(
cutlass_scaled_mm[1],
dim=0,
world_size=self.tp_size,
group_name=self.tp.unique_name,
)
return reduce_scatter
def replacement(
input: torch.Tensor,
mat2: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
cutlass_mm_output: torch.Tensor,
) -> torch.Tensor:
# Calculate output shape: input @ mat2 with scatter_dim reduced
output_shape = [*input.shape[:-1], mat2.shape[1]]
scatter_dim = 0
gemm_rs = torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter(
input,
mat2,
scale_a,
scale_b,
"avg",
scatter_dim, # orig_scatter_dim
scatter_dim, # scatter_dim_after_maybe_reshape
self.tp.device_group.group_name,
output_shape,
None, # bias
None, # result_scale
self.dtype, # out_dtype
False, # use_fast_accum
)
return gemm_rs
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
)
class AllGatherCutlassScaledMMPattern(BasePattern):
def get_inputs(self) -> list[torch.Tensor]:
x = torch.empty([8, 16], device=self.device, dtype=FP8_DTYPE)
weight = (
torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
.contiguous()
.transpose(0, 1)
)
s1 = x.shape[0] * self.tp_size
scale_a = torch.empty([s1, 1], device=self.device, dtype=torch.float32)
scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)
s2 = weight.shape[1]
output = torch.empty([s1, s2], device=self.device, dtype=self.dtype)
return [x, weight, scale_a, scale_b, output]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
x: torch.Tensor,
weight: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
output: torch.Tensor,
) -> torch.Tensor:
all_gather = torch.ops.vllm.all_gather.default(
x, dim=0, world_size=self.tp_size, group_name=self.tp.unique_name
)
cutlass_scaled_mm = torch.ops.higher_order.auto_functionalized(
torch.ops._C.cutlass_scaled_mm.default,
out=output,
a=all_gather,
b=weight,
a_scales=scale_a,
b_scales=scale_b,
bias=None,
)
return cutlass_scaled_mm[1]
def replacement(
x: torch.Tensor,
weight: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
output: torch.Tensor,
) -> torch.Tensor:
ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_scaled_matmul( # noqa
x,
[weight],
scale_a,
[scale_b],
gather_dim=0,
biases=[None],
result_scales=[None],
out_dtypes=[self.dtype],
use_fast_accum=[False],
group_name=self.tp.device_group.group_name,
)
return mm_outputs
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
)
class AsyncTPPass(VllmPatternMatcherPass):
@enable_fake_mode
def __init__(self, config: VllmConfig) -> None:
super().__init__(config)
# Enable symmetric memory for the TP process group
enable_symm_mem_for_group(get_tp_group().device_group.group_name)
self.patterns: PatternMatcherPass = PatternMatcherPass(
pass_name="async_tp_pass"
)
GEMMReduceScatterPattern(self.model_dtype, self.device).register(self.patterns)
AllGatherGEMMPattern(self.model_dtype, self.device).register(self.patterns)
# These fusions are enabled only for bfloat16 models because
# `scaled_mm` or `cutlass_scaled_mm` with per-token (row-wise) scaling
# only supports bfloat16 as the output dtype.
if self.model_dtype == torch.bfloat16:
ScaledMMReduceScatterPattern(self.model_dtype, self.device).register(
self.patterns
)
AllGatherScaledMMPattern(self.model_dtype, self.device).register(
self.patterns
)
CutlassScaledMMReduceScatterPattern(self.model_dtype, self.device).register(
self.patterns
)
AllGatherCutlassScaledMMPattern(self.model_dtype, self.device).register(
self.patterns
)
self.dump_patterns(config, self.patterns)
def is_applicable_for_range(self, compile_range: Range) -> bool:
# This pass is applied on top of the sequence parallelism pass.
# It inherits the same applicability condition as `SequenceParallelismPass`.
# See `SequenceParallelismPass.is_applicable` for more details.
if (
not self.compilation_config.splitting_ops
or self.compilation_config.use_inductor_graph_partition
):
return True
tp_size = get_tensor_model_parallel_world_size()
return bool(compile_range.is_single_size() and compile_range.end % tp_size == 0)
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph) -> None:
self.matched_count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", self.matched_count)

View File

@@ -0,0 +1,472 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from typing import Any
import torch
from torch._higher_order_ops import auto_functionalized
from torch._ops import OpOverload
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import get_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
QuantKey,
_normalize_quant_group_shape,
kFp8Dynamic64Sym,
kFp8Dynamic128Sym,
kFp8DynamicTensorSym,
kFp8DynamicTokenSym,
kFp8StaticTensorSym,
kNvfp4Dynamic,
)
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.platforms import current_platform
RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
ROTARY_OP = torch.ops._C.rotary_embedding.default
FLASHINFER_ROTARY_OP = torch.ops.vllm.flashinfer_rotary_embedding.default
QUANT_OPS: dict[QuantKey, OpOverload] = {
kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501
kFp8DynamicTensorSym: torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501
kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501
}
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.default # noqa: E501
if current_platform.is_cuda():
QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
QUANT_OPS[kFp8Dynamic64Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
SILU_MUL_OP = torch.ops._C.silu_and_mul.default
class MatcherCustomOp(ABC):
def __init__(self, enabled: bool) -> None:
config = get_current_vllm_config()
self.model_dtype = config.model_config.dtype if config.model_config else None
self.device = config.device_config.device if config.device_config else None
self.enabled = enabled
self.forward = self.forward_custom if enabled else self.forward_native
@abstractmethod
def forward_custom(self, *args: Any, **kwargs: Any) -> Any:
pass
@abstractmethod
def forward_native(self, *args: Any, **kwargs: Any) -> Any:
pass
def __call__(self, *args: Any, **kwargs: Any) -> Any:
return self.forward(*args, **kwargs)
def empty(self, *args: Any, **kwargs: Any) -> torch.Tensor:
return torch.empty(*args, dtype=self.model_dtype, device=self.device, **kwargs)
def empty_int64(self, *args: Any, **kwargs: Any) -> torch.Tensor:
return torch.empty(*args, dtype=torch.int64, device=self.device, **kwargs)
def empty_f32(self, *args: Any, **kwargs: Any) -> torch.Tensor:
return torch.empty(*args, dtype=torch.float32, device=self.device, **kwargs)
def inputs(self) -> list[torch.Tensor]:
"""Utility for inputs to the pattern"""
raise NotImplementedError
class MatcherRotaryEmbedding(MatcherCustomOp):
def __init__(
self,
is_neox: bool,
head_size: int,
num_heads: int,
num_kv_heads: int,
use_flashinfer: bool = False,
match_rocm_aiter: bool | None = None,
enabled: bool | None = None,
) -> None:
if enabled is None:
enabled = RotaryEmbedding.enabled()
if match_rocm_aiter is None:
match_rocm_aiter = rocm_aiter_ops.is_triton_rotary_embed_enabled()
super().__init__(enabled)
self.is_neox = is_neox
self.head_size = head_size
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads
self.q_size = self.num_heads * self.head_size
self.kv_size = self.num_kv_heads * self.head_size
self.rotary_dim = head_size
if use_flashinfer:
self.rotary_op = FLASHINFER_ROTARY_OP
elif match_rocm_aiter:
self.rotary_op = rocm_aiter_ops.get_triton_rotary_embedding_op()
else:
self.rotary_op = ROTARY_OP
def inputs(self) -> list[torch.Tensor]:
positions = self.empty_int64(5)
query = self.empty(5, self.q_size)
key = self.empty(5, self.kv_size)
cos_sin_cache = self.empty(4096, self.rotary_dim)
return [positions, query, key, cos_sin_cache]
def forward_custom(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor | None,
cos_sin_cache: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor | None]:
result = auto_functionalized(
self.rotary_op,
positions=positions,
query=query,
key=key,
head_size=self.head_size,
cos_sin_cache=cos_sin_cache,
is_neox=self.is_neox,
)
query_out = result[1]
key_out = result[2] if len(result) > 2 else None
return query_out, key_out
def forward_native(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor | None,
cos_sin_cache: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor | None]:
result: tuple[torch.Tensor, torch.Tensor | None] = (
RotaryEmbedding.forward_static(
positions,
query,
key,
self.head_size,
self.rotary_dim,
cos_sin_cache,
self.is_neox,
)
)
return result
class MatcherRMSNorm(MatcherCustomOp):
def __init__(
self,
epsilon: float,
enabled: bool | None = None,
match_rocm_aiter: bool = False,
) -> None:
if enabled is None:
enabled = RMSNorm.enabled()
super().__init__(enabled)
self.epsilon = epsilon
self._rmsnorm_op = RMS_OP
self.match_rocm_aiter = match_rocm_aiter
if match_rocm_aiter:
self._rmsnorm_op = rocm_aiter_ops.get_rmsnorm_op()
def inputs(self) -> list[torch.Tensor]:
input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16)
weight = self.empty(16)
return [input, weight]
def forward_rocm_aiter(
self,
input: torch.Tensor,
weight: torch.Tensor,
) -> torch.Tensor:
return self._rmsnorm_op(
x=input,
weight=weight,
variance_epsilon=self.epsilon,
)
def forward_custom(
self,
input: torch.Tensor,
weight: torch.Tensor,
) -> torch.Tensor:
if self.match_rocm_aiter:
return self.forward_rocm_aiter(input, weight)
result = torch.empty_like(input)
_, result = auto_functionalized(
self._rmsnorm_op,
result=result,
input=input,
weight=weight,
epsilon=self.epsilon,
)
return result
def forward_native(
self,
input: torch.Tensor,
weight: torch.Tensor,
) -> torch.Tensor:
return RMSNorm.forward_static(
input, self.epsilon, input.size(-1), self.model_dtype, weight
)
class MatcherFusedAddRMSNorm(MatcherCustomOp):
def __init__(
self,
epsilon: float,
enabled: bool | None = None,
match_rocm_aiter: bool = False,
) -> None:
if enabled is None:
enabled = RMSNorm.enabled()
super().__init__(enabled)
self.epsilon = epsilon
self.match_rocm_aiter = match_rocm_aiter
self._rmsnorm_op = RMS_ADD_OP
if match_rocm_aiter:
self._rmsnorm_op = rocm_aiter_ops.get_rmsnorm_fused_add_op()
def inputs(self) -> list[torch.Tensor]:
input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16)
weight = self.empty(16)
residual = self.empty(5, 16)
return [input, weight, residual]
def forward_rocm_aiter(
self,
input: torch.Tensor,
weight: torch.Tensor,
residual: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
return self._rmsnorm_op( # type: ignore[no-any-return]
x=input, residual=residual, weight=weight, variance_epsilon=self.epsilon
)
def forward_custom(
self,
input: torch.Tensor,
weight: torch.Tensor,
residual: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
if self.match_rocm_aiter:
return self.forward_rocm_aiter(input, weight, residual)
_, result, residual = auto_functionalized(
self._rmsnorm_op,
input=input,
residual=residual,
weight=weight,
epsilon=self.epsilon,
)
return result, residual
def forward_native(
self,
input: torch.Tensor,
weight: torch.Tensor,
residual: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
result: tuple[torch.Tensor, torch.Tensor] = RMSNorm.forward_static(
input, self.epsilon, input.size(-1), self.model_dtype, weight, residual
)
return result
class MatcherQuantFP8(MatcherCustomOp):
def __init__(
self,
quant_key: QuantKey,
enabled: bool | None = None,
has_col_major_scales: bool = False,
is_e8m0: bool = False,
match_rocm_aiter: bool = False,
is_tma_aligned: bool = False,
) -> None:
if enabled is None:
enabled = QuantFP8.enabled()
super().__init__(enabled)
self.quant_key = quant_key
self.has_col_major_scales = has_col_major_scales
self.is_e8m0 = is_e8m0
self.match_rocm_aiter = match_rocm_aiter
self.is_tma_aligned = is_tma_aligned
if match_rocm_aiter:
assert not quant_key.scale.group_shape.is_per_tensor(), (
"ROCm aiter fusion pass does not support per tensor quantization"
)
if quant_key.scale.group_shape.is_per_token():
self.QUANT_OP = rocm_aiter_ops.get_per_token_quant_op()
else:
assert quant_key.scale.group_shape.col == 128, (
"ROCm aiter fusion pass currently supports "
"quantization operation with group_size 128"
)
if current_platform.is_fp8_fnuz():
self.QUANT_OP = rocm_aiter_ops.get_group_quant_op()
else:
self.QUANT_OP = (
torch.ops.vllm.triton_per_token_group_quant_fp8.default
)
else:
assert quant_key in QUANT_OPS, (
f"unsupported quantization scheme {quant_key}"
)
self.QUANT_OP = QUANT_OPS[quant_key]
assert quant_key.dtype == current_platform.fp8_dtype(), (
"Only QuantFP8 supported by"
)
assert quant_key.scale2 is None
self.quant_fp8 = QuantFP8(
quant_key.scale.static,
quant_key.scale.group_shape,
column_major_scales=has_col_major_scales,
use_ue8m0=is_e8m0,
tma_aligned_scales=self.is_tma_aligned,
compile_native=False,
)
def forward_rocm_aiter(
self,
input: torch.Tensor,
scale: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
quant_key_group_shape = self.quant_key.scale.group_shape
if quant_key_group_shape == GroupShape.PER_TOKEN:
return self.QUANT_OP( # type: ignore[no-any-return]
x=input,
quant_dtype=self.quant_key.dtype,
scale=scale,
)
else:
return self.QUANT_OP(input, quant_key_group_shape.col) # type: ignore[no-any-return]
def forward_custom(
self,
input: torch.Tensor,
scale: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
if self.match_rocm_aiter:
return self.forward_rocm_aiter(input, scale)
result = torch.empty(
input.shape, device=input.device, dtype=self.quant_key.dtype
)
if self.quant_key.scale.group_shape.is_per_group():
# for tma_aligned, the scale must be passed to forward_custom
# tma_aligned fusion then matches by custom op arguments
if not self.is_tma_aligned:
assert scale is None
scale = self.make_scale(input, transposed=self.has_col_major_scales)
finfo = torch.finfo(self.quant_key.dtype)
fp8_min = finfo.min
fp8_max = finfo.max
_, result, scale = auto_functionalized(
self.QUANT_OP,
input=input,
output_q=result,
output_s=scale,
group_size=self.quant_key.scale.group_shape[1],
eps=1e-10,
fp8_min=fp8_min,
fp8_max=fp8_max,
scale_ue8m0=self.is_e8m0,
dummy_is_scale_transposed=self.has_col_major_scales,
dummy_is_tma_aligned=self.is_tma_aligned,
)
return result, scale
if self.quant_key.scale.static:
assert scale is not None
_, result = auto_functionalized(
self.QUANT_OP, result=result, input=input, scale=scale
)
return result, scale
else:
assert scale is None
scale = self.make_scale(input)
_, result, scale = auto_functionalized(
self.QUANT_OP, result=result, input=input, scale=scale, scale_ub=None
)
return result, scale
def forward_native(
self,
input: torch.Tensor,
scale: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
return self.quant_fp8(input, scale) # type: ignore[no-any-return]
def make_scale(self, input: torch.Tensor, transposed: bool = False) -> torch.Tensor:
normalized_group_shape = _normalize_quant_group_shape(
input, self.quant_key.scale.group_shape
)
scale_shape = (
input.shape[0] // normalized_group_shape[0],
input.shape[1] // normalized_group_shape[1],
)
if transposed:
scale_shape = tuple(reversed(scale_shape))
return torch.empty(
scale_shape, device=input.device, dtype=torch.float32
).permute(-1, -2)
return torch.empty(scale_shape, device=input.device, dtype=torch.float32)
def inputs(self) -> list[torch.Tensor]:
input = self.empty(5, 16)
if self.quant_key.scale.static:
return [input, self.empty_f32(1, 1)]
return [input]
class MatcherSiluAndMul(MatcherCustomOp):
def __init__(self, enabled: bool | None = None) -> None:
if enabled is None:
enabled = SiluAndMul.enabled()
super().__init__(enabled)
def inputs(self) -> list[torch.Tensor]:
input = self.empty(5, 4)
return [input]
def forward_custom(
self,
x: torch.Tensor,
) -> torch.Tensor:
d = x.shape[-1] // 2
output_shape = x.shape[:-1] + (d,)
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
result = auto_functionalized(SILU_MUL_OP, result=out, input=x)
return result[1]
def forward_native(
self,
x: torch.Tensor,
) -> torch.Tensor:
return SiluAndMul.forward_native(x)

View File

@@ -0,0 +1,244 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from typing import ParamSpec
import torch
import torch._inductor.pattern_matcher as pm
from torch import fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import PatternMatcherPass
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from ..inductor_pass import enable_fake_mode
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
from .matcher_utils import MatcherRMSNorm, MatcherRotaryEmbedding
from .rms_quant_fusion import empty_bf16, empty_fp32, empty_i64
logger = init_logger(__name__)
FUSED_QK_ROPE_OP = torch.ops._C.fused_qk_norm_rope.default
P = ParamSpec("P")
class QkNormRopePattern:
"""
Match the unfused sequence in attention blocks and replace with the fused op.
Unfused (conceptually):
q, k, v = split(qkv, [qsz, kvsz, kvsz], -1)
qh = reshape(q, [-1, num_heads, head_dim])
kh = reshape(k, [-1, num_kv_heads, head_dim])
qn = rms_norm(qh, q_weight, eps)
kn = rms_norm(kh, k_weight, eps)
qf = reshape(qn, [-1, num_heads * head_dim])
kf = reshape(kn, [-1, num_kv_heads * head_dim])
qf, kf = rotary_embedding(positions, qf, kf, head_dim, cos_sin_cache, is_neox)
return qf, kf, v
Fused replacement:
fused_qk_norm_rope(qkv, num_heads, num_kv_heads, num_kv_heads, head_dim,
eps, q_weight, k_weight, cos_sin_cache, is_neox,
positions.view(-1))
return split(qkv, [qsz, kvsz, kvsz], -1)
"""
def __init__(
self,
head_dim: int,
num_heads: int,
num_kv_heads: int,
eps: float,
is_neox: bool,
rope_flashinfer: bool = False,
) -> None:
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads
self.head_dim = head_dim
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.eps = eps
self.rmsnorm_matcher = MatcherRMSNorm(eps)
self.is_neox = is_neox
self.rope_flashinfer = rope_flashinfer
self.rope_matcher = MatcherRotaryEmbedding(
is_neox=is_neox,
head_size=self.head_dim,
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
use_flashinfer=self.rope_flashinfer,
)
def get_inputs(self) -> list[torch.Tensor]:
# Sample inputs to help pattern tracing
T = 5
qkv = empty_bf16(T, self.q_size + 2 * self.kv_size)
positions = empty_i64(T)
q_weight = empty_bf16(1, self.head_dim)
k_weight = empty_bf16(1, self.head_dim)
if self.rope_flashinfer:
cos_sin_cache = empty_fp32(4096, self.head_dim)
else:
cos_sin_cache = empty_bf16(4096, self.head_dim)
return [
qkv,
positions,
q_weight,
k_weight,
cos_sin_cache,
]
@staticmethod
def wrap_trace_fn(
trace_fn: Callable[P, fx.GraphModule],
*process_fx_fns: Callable[[fx.GraphModule], None],
) -> Callable[P, fx.GraphModule]:
def wrapped(*args: P.args, **kwargs: P.kwargs) -> fx.GraphModule:
gm = trace_fn(*args, **kwargs)
for process_fx in process_fx_fns:
process_fx(gm)
return gm
return wrapped
@staticmethod
def fx_view_to_reshape(gm: torch.fx.GraphModule) -> None:
from torch._inductor.fx_passes.post_grad import view_to_reshape
view_to_reshape(gm)
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
qkv: torch.Tensor,
positions: torch.Tensor,
q_weight: torch.Tensor,
k_weight: torch.Tensor,
cos_sin_cache: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# split qkv -> q,k,v
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
# Q path: view -> RMS -> view back to q.shape
q_by_head = q.view(
*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim
)
q_normed_by_head = self.rmsnorm_matcher(q_by_head, q_weight)
q_flat = q_normed_by_head.view(q.shape)
# K path: view -> RMS -> view back to k.shape
k_by_head = k.view(
*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim
)
k_normed_by_head = self.rmsnorm_matcher(k_by_head, k_weight)
k_flat = k_normed_by_head.view(k.shape)
# RoPE: apply to flattened q/k
q_rope, k_rope = self.rope_matcher(positions, q_flat, k_flat, cos_sin_cache)
return q_rope, k_rope, v
def replacement(
qkv: torch.Tensor,
positions: torch.Tensor,
q_weight: torch.Tensor,
k_weight: torch.Tensor,
cos_sin_cache: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# Run fused qk_norm_rope op
result = auto_functionalized(
FUSED_QK_ROPE_OP,
qkv=qkv,
num_heads_q=self.num_heads,
num_heads_k=self.num_kv_heads,
num_heads_v=self.num_kv_heads,
head_dim=self.head_dim,
eps=self.eps,
q_weight=q_weight,
k_weight=k_weight,
cos_sin_cache=cos_sin_cache,
is_neox=self.is_neox,
position_ids=positions.view(-1),
)
result_qkv = result[1]
# Split back to q,k,v and return
return result_qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) # type: ignore[no-any-return]
# NOTE: use fx_view_to_reshape to unify view/reshape to simplify
# pattern and increase matching opportunities
pm.register_replacement(
pattern,
replacement,
self.get_inputs(),
QkNormRopePattern.wrap_trace_fn(
pm.fwd_only,
QkNormRopePattern.fx_view_to_reshape,
),
pm_pass,
)
class QKNormRoPEFusionPass(VllmPatternMatcherPass):
"""Fuse Q/K RMSNorm + RoPE into fused_qk_norm_rope when the custom op exists."""
@enable_fake_mode
def __init__(self, config: VllmConfig) -> None:
super().__init__(config)
self.patterns: PatternMatcherPass = PatternMatcherPass(
pass_name="qk_norm_rope_fusion_pass"
)
dtype = config.model_config.dtype
if dtype not in (torch.bfloat16, torch.float16):
logger.warning_once(
"QK Norm+RoPE fusion not enabled: unsupported dtype %s", dtype
)
return
# use one attn layer to get meta (such as head_dim) for QkNormRopePattern
attn_layers: dict[str, Attention] = get_layers_from_vllm_config(
config, Attention
)
if len(attn_layers) == 0:
logger.warning_once(
"QK Norm+RoPE fusion enabled, but no Attention layers were discovered."
)
return
layer = next(iter(attn_layers.values()))
for epsilon in [1e-5, 1e-6]:
for neox in [True, False]:
if RotaryEmbedding.enabled():
for rope_flashinfer in [False, True]:
QkNormRopePattern(
head_dim=layer.head_size,
num_heads=layer.num_heads,
num_kv_heads=layer.num_kv_heads,
eps=epsilon,
is_neox=neox,
rope_flashinfer=rope_flashinfer,
).register(self.patterns)
else:
QkNormRopePattern(
head_dim=layer.head_size,
num_heads=layer.num_heads,
num_kv_heads=layer.num_kv_heads,
eps=epsilon,
is_neox=neox,
).register(self.patterns)
self.dump_patterns(config, self.patterns)
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph) -> None:
self.matched_count = self.patterns.apply(graph)
logger.debug("Fused QK Norm+RoPE on %s sites", self.matched_count)
def uuid(self) -> str:
return VllmInductorPass.hash_source(self, QkNormRopePattern)

View File

@@ -0,0 +1,643 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, NamedTuple
import torch
import torch._inductor.pattern_matcher as pm
from torch import fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch._ops import OpOverload
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
QuantKey,
ScaleDesc,
kFp8Dynamic64Sym,
kFp8Dynamic128Sym,
kFp8DynamicTensorSym,
kFp8DynamicTokenSym,
kFp8StaticTensorSym,
kNvfp4Dynamic,
kStaticTensorScale,
)
from vllm.platforms import current_platform
from ..inductor_pass import enable_fake_mode
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
from .matcher_utils import (
MatcherFusedAddRMSNorm,
MatcherQuantFP8,
MatcherRMSNorm,
)
logger = init_logger(__name__)
FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8
def empty_bf16(*args: Any, **kwargs: Any) -> torch.Tensor:
return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device="cuda")
def empty_fp32(*args: Any, **kwargs: Any) -> torch.Tensor:
return torch.empty(*args, **kwargs, dtype=torch.float32, device="cuda")
def empty_i32(*args: Any, **kwargs: Any) -> torch.Tensor:
return torch.empty(*args, **kwargs, dtype=torch.int32, device="cuda")
def empty_i64(*args: Any, **kwargs: Any) -> torch.Tensor:
return torch.empty(*args, **kwargs, dtype=torch.int64, device="cuda")
RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
QUANT_OPS: dict[QuantKey, OpOverload] = {
kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501
kFp8DynamicTensorSym: torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501
kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501
}
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.default
if current_platform.is_cuda():
QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
QUANT_OPS[kFp8Dynamic64Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
class FusedRMSQuantKey(NamedTuple):
"""
Named tuple for identifying the type of RMSNorm + quant fusion.
quant: type of quantization
fused_add: does the op also perform the residual add
"""
quant: QuantKey
fused_add: bool
def __str__(self) -> str:
return (
f"FusedQuantKey({self.quant}, with"
f"{'' if self.fused_add else 'out'} residual)"
)
FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = {
FusedRMSQuantKey(
kFp8StaticTensorSym, False
): torch.ops._C.rms_norm_static_fp8_quant.default, # noqa: E501
FusedRMSQuantKey(
kFp8StaticTensorSym, True
): torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, # noqa: E501
FusedRMSQuantKey(
kFp8DynamicTokenSym, False
): torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa: E501
FusedRMSQuantKey(
kFp8DynamicTokenSym, True
): torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa: E501
FusedRMSQuantKey(
kFp8Dynamic128Sym, False
): torch.ops._C.rms_norm_per_block_quant.default, # noqa: E501
FusedRMSQuantKey(
kFp8Dynamic128Sym, True
): torch.ops._C.rms_norm_per_block_quant.default, # noqa: E501
FusedRMSQuantKey(
kFp8Dynamic64Sym, False
): torch.ops._C.rms_norm_per_block_quant.default, # noqa: E501
FusedRMSQuantKey(
kFp8Dynamic64Sym, True
): torch.ops._C.rms_norm_per_block_quant.default, # noqa: E501
}
class RMSNormQuantPattern:
def __init__(
self,
epsilon: float,
key: FusedRMSQuantKey,
has_col_major_scales: bool = False,
is_e8m0: bool = False,
is_tma_aligned: bool = False,
) -> None:
self.epsilon = epsilon
self.quant_dtype = key.quant.dtype
config = get_current_vllm_config()
self.model_dtype = config.model_config.dtype if config.model_config else None
assert key in FUSED_OPS, f"unsupported fused rmsnorm+quant op for {key}"
self.FUSED_OP = FUSED_OPS[key]
self.rmsnorm_matcher = (
MatcherRMSNorm(epsilon)
if not key.fused_add
else MatcherFusedAddRMSNorm(epsilon)
)
self.quant_matcher = MatcherQuantFP8(
key.quant,
has_col_major_scales=has_col_major_scales,
is_e8m0=is_e8m0,
is_tma_aligned=is_tma_aligned,
)
class RMSNormStaticQuantPattern(RMSNormQuantPattern):
def __init__(
self, epsilon: float, quant_dtype: torch.dtype, symmetric: bool = True
) -> None:
fused_key = FusedRMSQuantKey(
fused_add=False,
quant=QuantKey(
dtype=quant_dtype, scale=kStaticTensorScale, symmetric=symmetric
),
)
super().__init__(epsilon, fused_key)
def register(self, pm_pass: PatternMatcherPass) -> None:
# Cannot use methods, as the self argument affects tracing
def pattern(
input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
) -> torch.Tensor:
result_rms = self.rmsnorm_matcher(input, weight)
return self.quant_matcher(result_rms, scale)[0]
def replacement(
input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
) -> torch.Tensor:
# In case we're matching native rms-norm, conversions might be
# optimized out. We convert here just to be safe.
input = input.to(dtype=self.model_dtype)
result = torch.empty(
input.shape, device=input.device, dtype=self.quant_dtype
)
at = auto_functionalized(
self.FUSED_OP,
result=result,
input=input,
weight=weight,
scale=scale,
epsilon=self.epsilon,
)
# result
return at[1]
inputs = [
# input, weight
*self.rmsnorm_matcher.inputs(),
self.quant_matcher.inputs()[1], # scale
]
pattern(*inputs)
pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass)
class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
def __init__(
self, epsilon: float, quant_dtype: torch.dtype, symmetric: bool = True
) -> None:
key = FusedRMSQuantKey(
fused_add=True,
quant=QuantKey(
dtype=quant_dtype, scale=kStaticTensorScale, symmetric=symmetric
),
)
super().__init__(epsilon, key)
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor,
weight: torch.Tensor,
residual: torch.Tensor,
scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
result, _ = self.quant_matcher(result_rms, scale)
return result, residual
def replacement(
input: torch.Tensor,
weight: torch.Tensor,
residual: torch.Tensor,
scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
# In case we're matching native rms-norm, conversions might be
# optimized out. We convert here just to be safe.
input = input.to(dtype=self.model_dtype)
result = torch.empty_like(input, dtype=self.quant_dtype)
at = auto_functionalized(
self.FUSED_OP,
result=result,
input=input,
residual=residual,
weight=weight,
scale=scale,
epsilon=self.epsilon,
)
# result, residual
return at[1], at[2]
inputs = [
# input, weight, residual
*self.rmsnorm_matcher.inputs(),
self.quant_matcher.inputs()[1], # scale
]
pm.register_replacement(
pattern,
replacement,
inputs,
pm.fwd_only,
pm_pass,
)
class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern):
def __init__(
self,
epsilon: float,
quant_dtype: torch.dtype,
group_shape: GroupShape,
symmetric: bool = True,
is_e8m0: bool = False,
has_col_major_scales: bool = True,
is_tma_aligned: bool = True,
) -> None:
scale = ScaleDesc(torch.float32, False, group_shape)
key = FusedRMSQuantKey(
fused_add=True,
quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
)
self.group_shape = group_shape
self.is_e8m0 = is_e8m0
self.has_col_major_scales = has_col_major_scales
self.is_tma_aligned = is_tma_aligned
super().__init__(
epsilon,
key,
has_col_major_scales=has_col_major_scales,
is_e8m0=is_e8m0,
is_tma_aligned=is_tma_aligned,
)
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor,
weight: torch.Tensor,
residual: torch.Tensor,
scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
result = torch.empty(
result_rms.shape,
device=result_rms.device,
dtype=self.quant_matcher.quant_key.dtype,
)
assert scale is not None
finfo = torch.finfo(self.quant_matcher.quant_key.dtype)
fp8_min = finfo.min
fp8_max = finfo.max
_, result, scale = auto_functionalized(
self.quant_matcher.QUANT_OP,
input=result_rms,
output_q=result,
output_s=scale,
group_size=self.quant_matcher.quant_key.scale.group_shape[1],
eps=1e-10,
fp8_min=fp8_min,
fp8_max=fp8_max,
scale_ue8m0=self.quant_matcher.is_e8m0,
dummy_is_scale_transposed=self.has_col_major_scales,
dummy_is_tma_aligned=self.is_tma_aligned,
)
return result, residual, scale
def replacement(
input: torch.Tensor,
weight: torch.Tensor,
residual: torch.Tensor,
scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# In case we're matching native rms-norm, conversions might be
# optimized out. We convert here just to be safe.
input = input.to(dtype=self.model_dtype)
result = torch.empty_like(input, dtype=self.quant_dtype)
at = auto_functionalized(
self.FUSED_OP,
result=result,
input=input,
weight=weight,
scale=scale,
epsilon=self.epsilon,
scale_ub=None,
residual=residual,
group_size=self.group_shape[1],
is_scale_transposed=self.has_col_major_scales,
)
# result, residual, scale
return at[1], at[3], at[2]
scale = self.quant_matcher.empty_f32(1, 1)
pm.register_replacement(
pattern,
replacement,
self.rmsnorm_matcher.inputs() + [scale],
pm.fwd_only,
pm_pass,
)
class RMSNormGroupQuantPattern(RMSNormQuantPattern):
def __init__(
self,
epsilon: float,
quant_dtype: torch.dtype,
group_shape: GroupShape,
symmetric: bool = True,
is_e8m0: bool = False,
has_col_major_scales: bool = True,
is_tma_aligned: bool = True,
) -> None:
scale = ScaleDesc(torch.float32, False, group_shape)
key = FusedRMSQuantKey(
fused_add=False,
quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
)
self.group_shape = group_shape
self.has_col_major_scales = has_col_major_scales
self.is_tma_aligned = is_tma_aligned
super().__init__(
epsilon,
key,
has_col_major_scales=self.has_col_major_scales,
is_e8m0=is_e8m0,
is_tma_aligned=is_tma_aligned,
)
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
result_rms = self.rmsnorm_matcher(input, weight)
result = torch.empty(
result_rms.shape,
device=result_rms.device,
dtype=self.quant_matcher.quant_key.dtype,
)
assert scale is not None
finfo = torch.finfo(self.quant_matcher.quant_key.dtype)
fp8_min = finfo.min
fp8_max = finfo.max
_, result, scale = auto_functionalized(
self.quant_matcher.QUANT_OP,
input=result_rms,
output_q=result,
output_s=scale,
group_size=self.quant_matcher.quant_key.scale.group_shape[1],
eps=1e-10,
fp8_min=fp8_min,
fp8_max=fp8_max,
scale_ue8m0=self.quant_matcher.is_e8m0,
dummy_is_scale_transposed=self.has_col_major_scales,
dummy_is_tma_aligned=self.is_tma_aligned,
)
return result, scale
def replacement(
input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
# In case we're matching native rms-norm, conversions might be
# optimized out. We convert here just to be safe.
input = input.to(dtype=self.model_dtype)
result = torch.empty_like(input, dtype=self.quant_dtype)
at = auto_functionalized(
self.FUSED_OP,
result=result,
input=input,
weight=weight,
scale=scale,
epsilon=self.epsilon,
scale_ub=None,
residual=None,
group_size=self.group_shape[1],
is_scale_transposed=self.has_col_major_scales,
)
# result, scale
return at[1], at[2]
scale = self.quant_matcher.empty_f32(1, 1)
pm.register_replacement(
pattern,
replacement,
self.rmsnorm_matcher.inputs() + [scale],
pm.fwd_only,
pm_pass,
)
class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
def __init__(
self,
epsilon: float,
quant_dtype: torch.dtype,
group_shape: GroupShape = GroupShape.PER_TOKEN,
symmetric: bool = True,
) -> None:
scale = ScaleDesc(torch.float32, False, group_shape)
key = FusedRMSQuantKey(
fused_add=False,
quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
)
super().__init__(epsilon, key)
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor, weight: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
result_rms = self.rmsnorm_matcher(input, weight)
# result, scale
return self.quant_matcher(result_rms) # type: ignore[no-any-return]
def replacement(
input: torch.Tensor, weight: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
# In case we're matching native rms-norm, conversions might be
# optimized out. We convert here just to be safe.
input = input.to(dtype=self.model_dtype)
result = torch.empty_like(input, dtype=self.quant_dtype)
scale = self.quant_matcher.make_scale(input)
at = auto_functionalized(
self.FUSED_OP,
result=result,
input=input,
weight=weight,
scale=scale,
epsilon=self.epsilon,
scale_ub=None,
residual=None,
)
# result, scale
return at[1], at[2]
pm.register_replacement(
pattern,
replacement,
self.rmsnorm_matcher.inputs(),
pm.fwd_only,
pm_pass,
)
class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
def __init__(
self,
epsilon: float,
quant_dtype: torch.dtype,
group_shape: GroupShape = GroupShape.PER_TOKEN,
symmetric: bool = True,
) -> None:
scale = ScaleDesc(torch.float32, False, group_shape)
key = FusedRMSQuantKey(
fused_add=True,
quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
)
super().__init__(epsilon, key)
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
result, scale = self.quant_matcher(result_rms)
return result, residual, scale
def replacement(
input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# In case we're matching native rms-norm, conversions might be
# optimized out. We convert here just to be safe.
input = input.to(dtype=self.model_dtype)
result = torch.empty_like(input, dtype=self.quant_dtype)
scale = self.quant_matcher.make_scale(input)
at = auto_functionalized(
self.FUSED_OP,
result=result,
input=input,
weight=weight,
scale=scale,
epsilon=self.epsilon,
scale_ub=None,
residual=residual,
)
# result, residual, scale
return at[1], at[3], at[2]
pm.register_replacement(
pattern,
replacement,
self.rmsnorm_matcher.inputs(),
pm.fwd_only,
pm_pass,
)
class RMSNormQuantFusionPass(VllmPatternMatcherPass):
"""
This pass fuses rms_norm & quant custom ops into a fused rms_norm_quant op.
It also supports fused_add_rms_norm.
"""
@enable_fake_mode
def __init__(self, config: VllmConfig) -> None:
super().__init__(config)
self.patterns: PatternMatcherPass = PatternMatcherPass(
pass_name="rmsnorm_quant_fusion_pass"
)
# Make sure fused add patterns are before simple rms norm,
# as the latter is a subset of the former in torch ops
for epsilon in [1e-5, 1e-6]:
# Fuse fused_add_rms_norm + static fp8 quant
FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(
self.patterns
)
# Fuse rms_norm + static fp8 quant
RMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)
# Fuse fused_add_rms_norm + dynamic per-token fp8 quant
FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
self.patterns
)
# Fuse rms_norm + dynamic per-token fp8 quant
RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)
# Only register group quant patterns on CUDA where the C++ op exists
if current_platform.is_cuda():
for group_shape in [GroupShape(1, 128), GroupShape(1, 64)]:
for has_col_major_scales in [True, False]:
for is_e8m0 in [True, False]:
for is_tma_aligned in [False, True]:
# Fuse fused_add_rms_norm + fp8 group quant
FusedAddRMSNormGroupQuantPattern(
epsilon,
FP8_DTYPE,
group_shape=group_shape,
is_e8m0=is_e8m0,
has_col_major_scales=has_col_major_scales,
is_tma_aligned=is_tma_aligned,
).register(self.patterns)
# Fuse rms_norm + fp8 group quant
RMSNormGroupQuantPattern(
epsilon,
FP8_DTYPE,
group_shape=group_shape,
is_e8m0=is_e8m0,
has_col_major_scales=has_col_major_scales,
is_tma_aligned=is_tma_aligned,
).register(self.patterns)
self.dump_patterns(config, self.patterns)
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph) -> None:
self.matched_count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", self.matched_count)
def uuid(self) -> str:
return self.hash_source(
self,
RMSNormGroupQuantPattern,
RMSNormQuantPattern,
RMSNormStaticQuantPattern,
RMSNormDynamicQuantPattern,
FusedAddRMSNormStaticQuantPattern,
FusedAddRMSNormDynamicQuantPattern,
FusedAddRMSNormGroupQuantPattern,
)

View File

@@ -0,0 +1,504 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import torch._inductor.pattern_matcher as pm
from torch import fx
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch._ops import OpOverload
import vllm.model_executor.layers.quantization.utils.fp8_utils # noqa: F401
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
QuantKey,
ScaleDesc,
)
from vllm.platforms import current_platform
from ..inductor_pass import enable_fake_mode
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
from .act_quant_fusion import ActivationQuantPattern
from .matcher_utils import (
MatcherFusedAddRMSNorm,
MatcherQuantFP8,
MatcherRMSNorm,
MatcherSiluAndMul,
)
from .rms_quant_fusion import (
FusedRMSQuantKey,
)
logger = init_logger(__name__)
FP8_DTYPE = current_platform.fp8_dtype()
class AiterRMSNormQuantPattern:
def __init__(
self, epsilon: float, key: FusedRMSQuantKey, match_aiter_quant: bool = True
):
self.epsilon = epsilon
self.quant_dtype = key.quant.dtype
self.rmsnorm_matcher = (
MatcherRMSNorm(epsilon, match_rocm_aiter=True)
if not key.fused_add
else MatcherFusedAddRMSNorm(epsilon, match_rocm_aiter=True)
)
self.quant_matcher = MatcherQuantFP8(
key.quant,
match_rocm_aiter=match_aiter_quant,
)
class AiterRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern):
"""AITER RMSNorm + Dynamic Quantization pattern."""
FUSED_OP = rocm_aiter_ops.get_rmsnorm_fused_dynamic_quant_op()
def __init__(
self,
epsilon: float,
quant_dtype: torch.dtype,
match_aiter_quant: bool = True,
group_shape: GroupShape = GroupShape.PER_TOKEN,
symmetric: bool = True,
) -> None:
scale = ScaleDesc(torch.float32, False, group_shape)
key = FusedRMSQuantKey(
fused_add=False,
quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
)
super().__init__(epsilon, key, match_aiter_quant)
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor,
weight: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
result_rms = self.rmsnorm_matcher(input, weight)
result, scale = self.quant_matcher(result_rms)
return result, scale
def replacement(
input: torch.Tensor,
weight: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
result = self.FUSED_OP(
x=input,
weight=weight,
epsilon=self.epsilon,
quant_dtype=self.quant_dtype,
)
return result[0], result[1]
pm.register_replacement(
pattern,
replacement,
self.rmsnorm_matcher.inputs(),
pm.fwd_only,
pm_pass,
)
class AiterFusedAddRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern):
"""AITER RMSNorm Fused Add + Dynamic Quantization pattern."""
FUSED_OP = rocm_aiter_ops.get_rmsnorm_fused_add_dynamic_quant_op()
def __init__(
self,
epsilon: float,
quant_dtype: torch.dtype,
match_aiter_quant: bool = True,
group_shape: GroupShape = GroupShape.PER_TOKEN,
symmetric: bool = True,
) -> None:
scale = ScaleDesc(torch.float32, False, group_shape)
key = FusedRMSQuantKey(
fused_add=True,
quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
)
super().__init__(epsilon, key, match_aiter_quant)
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor,
weight: torch.Tensor,
residual: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
result_rms, residual_out = self.rmsnorm_matcher(input, weight, residual)
result, scale = self.quant_matcher(result_rms)
return result, residual_out, scale
def replacement(
input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
result = self.FUSED_OP(
x=input,
residual=residual,
weight=weight,
epsilon=self.epsilon,
quant_dtype=self.quant_dtype,
)
return result[0], result[1], result[2]
pm.register_replacement(
pattern,
replacement,
self.rmsnorm_matcher.inputs(),
pm.fwd_only,
pm_pass,
)
class AiterRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern):
"""
This pattern fuses aiter rms_norm & group fp8 quant custom
ops into an aiter rms_norm_group_fp8_quant op.
"""
FUSED_OP = rocm_aiter_ops.get_rmsnorm_group_fused_quant_op()
def __init__(
self,
epsilon: float,
quant_dtype: torch.dtype,
group_shape: GroupShape,
match_aiter_quant: bool = True,
symmetric: bool = True,
) -> None:
scale = ScaleDesc(torch.float32, False, group_shape)
key = FusedRMSQuantKey(
fused_add=False,
quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
)
super().__init__(epsilon, key, match_aiter_quant)
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor,
weight: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
result_rms = self.rmsnorm_matcher(input, weight)
result, scale = self.quant_matcher(result_rms)
return result, scale
def replacement(
input: torch.Tensor,
weight: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
at = self.FUSED_OP(
x=input,
weight=weight,
variance_epsilon=self.epsilon,
group_size=128,
)
return at[0], at[1]
pm.register_replacement(
pattern, replacement, self.rmsnorm_matcher.inputs(), pm.fwd_only, pm_pass
)
class AiterFusedAddRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern):
"""
This pattern fuses aiter rms_norm_with_add & group fp8 quant custom ops
into a aiter rms_norm_with_add_group_fp8_quant op.
"""
FUSED_OP = rocm_aiter_ops.get_rmsnorm_group_add_fused_quant_op()
def __init__(
self,
epsilon: float,
quant_dtype: torch.dtype,
group_shape: GroupShape,
match_aiter_quant: bool = True,
symmetric: bool = True,
) -> None:
scale = ScaleDesc(torch.float32, False, group_shape)
key = FusedRMSQuantKey(
fused_add=True,
quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
)
super().__init__(epsilon, key, match_aiter_quant)
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor,
weight: torch.Tensor,
residual: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
result_rms, residual_out = self.rmsnorm_matcher(input, weight, residual)
result, scale = self.quant_matcher(result_rms)
return result, residual_out, scale
def replacement(
input: torch.Tensor,
weight: torch.Tensor,
residual: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
at = self.FUSED_OP(
x=input,
residual=residual,
weight=weight,
variance_epsilon=self.epsilon,
group_size=128,
)
# result, scale, residual
return at[0], at[1], at[2]
pm.register_replacement(
pattern, replacement, self.rmsnorm_matcher.inputs(), pm.fwd_only, pm_pass
)
class RocmAiterRMSNormQuantFusionPass(VllmPatternMatcherPass):
"""
This pass fuses aiter rms_norm & vllm/aiter quant custom ops
into a fused rms_norm_quant op.
It also supports fused_add_rms_norm.
"""
@enable_fake_mode
def __init__(self, config: VllmConfig) -> None:
super().__init__(config)
self.patterns: PatternMatcherPass = PatternMatcherPass(
pass_name="rocm_aiter_rms_norm_quant_fusion_pass"
)
# Make sure fused add patterns are before simple rms norm,
# as the latter is a subset of the former in torch ops
for epsilon in [1e-5, 1e-6]:
# Fuse aiter rms_norm + aiter dynamic group fp8 quant
AiterRMSFp8GroupQuantPattern(
epsilon, FP8_DTYPE, GroupShape(1, 128)
).register(self.patterns)
# Fuse aiter fused_add_rms_norm + aiter dynamic group fp8 quant
AiterFusedAddRMSFp8GroupQuantPattern(
epsilon, FP8_DTYPE, GroupShape(1, 128)
).register(self.patterns)
for match_aiter_quant in [True, False]:
# Fuse aiter rms_norm + (aiter / vllm built-in)
# dynamic per-token fp8 quant
AiterRMSNormDynamicQuantPattern(
epsilon, FP8_DTYPE, match_aiter_quant=match_aiter_quant
).register(self.patterns)
# Fuse aiter fused_add_rms_norm + (aiter / vllm built-in)
# dynamic per-token fp8 quant
AiterFusedAddRMSNormDynamicQuantPattern(
epsilon, FP8_DTYPE, match_aiter_quant=match_aiter_quant
).register(self.patterns)
self.dump_patterns(config, self.patterns)
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph) -> None:
self.matched_count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", self.matched_count)
def uuid(self) -> str:
fusion_patterns = [
AiterRMSNormDynamicQuantPattern,
AiterFusedAddRMSNormDynamicQuantPattern,
AiterRMSFp8GroupQuantPattern,
AiterFusedAddRMSFp8GroupQuantPattern,
]
return self.hash_source(self, *fusion_patterns)
class AiterSiluMulFp8GroupQuantPattern(ActivationQuantPattern):
"""
This pattern fuses aiter silu_and_mul & group fp8 quant custom
ops into an aiter silu_and_mul_group_fp8_quant op.
"""
FUSED_SILU_MUL_QUANT_OP = rocm_aiter_ops.get_act_mul_fused_fp8_group_quant_op()
def __init__(self, quant_op: OpOverload) -> None:
self.silu_and_mul_matcher = MatcherSiluAndMul()
self.quant_op = quant_op
def get_inputs(self) -> list[torch.Tensor]:
return [
self.silu_and_mul_matcher.inputs()[0],
]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
at1 = self.silu_and_mul_matcher(input)
at2 = self.quant_op(at1, 128)
return at2[0], at2[1]
def replacement(
input: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
at = self.FUSED_SILU_MUL_QUANT_OP(x=input, group_size=128)
return at[0], at[1]
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
)
class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass):
"""
This pass fuses a pre-defined set of custom ops into fused ops.
It uses the torch pattern matcher to find the patterns and replace them.
Because patterns can only be registered once, the pass is a singleton.
This will be addressed in a future version of PyTorch:
https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
"""
AITER_GROUP_FP8_QUANT_OP = rocm_aiter_ops.get_group_quant_op()
TRITON_GROUP_FP8_QUANT_OP = torch.ops.vllm.triton_per_token_group_quant_fp8.default
QUANT_OPS = [AITER_GROUP_FP8_QUANT_OP, TRITON_GROUP_FP8_QUANT_OP]
@enable_fake_mode
def __init__(self, config: VllmConfig) -> None:
super().__init__(config)
self.patterns: PatternMatcherPass = PatternMatcherPass(
pass_name="rocm_aiter_silu_mul_fp8_group_quant_fusion_pass"
)
for quant_op in self.QUANT_OPS:
AiterSiluMulFp8GroupQuantPattern(quant_op).register(self.patterns)
self.dump_patterns(config, self.patterns)
@VllmInductorPass.time_and_log
def __call__(self, graph: torch.fx.Graph) -> None:
self.matched_count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", self.matched_count)
def uuid(self) -> str:
fusion_patterns = [
ActivationQuantPattern,
AiterSiluMulFp8GroupQuantPattern,
]
return VllmInductorPass.hash_source(self, *fusion_patterns)
class AddAiterRMSNormPadPattern:
"""
This pattern replaces an aiter_rmsnorm_with_add & a pad op
with a custom triton_add_rmsnorm_pad op from AITER.
"""
AITER_TRITON_ADD_RMSNORM_PAD_OP = rocm_aiter_ops.get_triton_add_rmsnorm_pad_op()
def __init__(
self,
epsilon: float,
hidden_size: int,
x_pad_to_multiple: int,
):
self.epsilon = epsilon
self.hidden_size = hidden_size
self.x_pad_to_multiple = x_pad_to_multiple
self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon, match_rocm_aiter=True)
def get_inputs(self) -> list[torch.Tensor]:
input, weight, residual = self.rmsnorm_matcher.inputs()
router_weight = torch.empty([8, 16], dtype=weight.dtype, device=weight.device)
router_bias = torch.empty([8], dtype=weight.dtype, device=weight.device)
return [input, weight, residual, router_weight, router_bias]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor,
weight: torch.Tensor,
residual: torch.Tensor,
router_weight: torch.Tensor,
router_bias: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
pad_size = self.x_pad_to_multiple - (
self.hidden_size % self.x_pad_to_multiple
)
result_rms, residual_out = self.rmsnorm_matcher(input, weight, residual)
router_logits = torch.ops.vllm.rocm_unquantized_gemm(
result_rms, router_weight, router_bias
)
result = torch.nn.functional.pad(
result_rms, (0, pad_size), mode="constant", value=0.0
)
return result, residual_out, router_logits
def replacement(
input: torch.Tensor,
weight: torch.Tensor,
residual: torch.Tensor,
router_weight: torch.Tensor,
router_bias: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
at = self.AITER_TRITON_ADD_RMSNORM_PAD_OP(
x=input,
weight=weight,
variance_epsilon=self.epsilon,
residual=residual,
x_pad_to_multiple=self.x_pad_to_multiple,
)
result_padded = at[0]
router_logits = torch.ops.vllm.rocm_unquantized_gemm(
result_padded[:, : self.hidden_size], router_weight, router_bias
)
residual_out = at[1]
return result_padded, residual_out, router_logits
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
)
class RocmAiterTritonAddRMSNormPadFusionPass(VllmPatternMatcherPass):
"""
This pass replaces an AITER CK RMSNorm + residual add and a pad op
with an triton_add_rmsnorm_pad op from AITER.
"""
def __init__(self, config: VllmConfig):
super().__init__(config)
self.patterns: PatternMatcherPass = PatternMatcherPass(
pass_name="rocm_aiter_triton_add_rmsnorm_pad_fusion_pass"
)
# gpt-oss has hidden size 2880
# padded to a multiple of 128 on gfx942 and 256 on gfx950 respectively
hidden_size = 2880
for epsilon in [1e-5, 1e-6]:
for x_pad_to_multiple in [128, 256]:
AddAiterRMSNormPadPattern(
epsilon, hidden_size, x_pad_to_multiple
).register(self.patterns)
self.dump_patterns(config, self.patterns)
@VllmInductorPass.time_and_log
def __call__(self, graph: torch.fx.Graph) -> None:
self.matched_count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", self.matched_count)
def uuid(self) -> str:
return VllmInductorPass.hash_source(self, AddAiterRMSNormPadPattern)

View File

@@ -0,0 +1,230 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import torch._inductor.pattern_matcher as pm
from torch import fx
from torch._higher_order_ops import auto_functionalized
from torch._inductor.fx_passes.post_grad import view_to_reshape
from torch._inductor.pattern_matcher import PatternMatcherPass
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.config.utils import Range
from vllm.logger import init_logger
from vllm.model_executor.layers.attention.attention import (
Attention,
get_attention_context,
)
from vllm.utils.torch_utils import direct_register_custom_op
from ..inductor_pass import enable_fake_mode
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
from .matcher_utils import (
MatcherRotaryEmbedding,
)
from .rms_quant_fusion import (
empty_bf16,
empty_i64,
)
logger = init_logger(__name__)
def fused_rope_and_unified_kv_cache_update_impl(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
positions: torch.Tensor,
cos_sin_cache: torch.Tensor,
is_neox: bool,
layer_name: str = "",
) -> torch.Tensor:
"""
This impl fetches the KV cache and slot mapping from the forward context,
then calls the layer impl's `AttentionImpl.do_rope_and_kv_cache_update` method.
It also returns a dummy tensor, similar to `Attention.unified_kv_cache_update`,
that is passed to unified_attention to signal a side effect and
the data dependency between them to ensure torch.compile preserves ordering.
"""
_, attn_layer, kv_cache, layer_slot_mapping = get_attention_context(layer_name)
if layer_slot_mapping is not None:
attn_layer.impl.do_rope_and_kv_cache_update(
attn_layer,
query,
key,
value,
positions,
cos_sin_cache,
is_neox,
kv_cache,
layer_slot_mapping,
)
return torch.empty(0, device=kv_cache.device, dtype=kv_cache.dtype)
def fused_rope_and_unified_kv_cache_update_fake(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
positions: torch.Tensor,
cos_sin_cache: torch.Tensor,
is_neox: bool,
layer_name: str = "",
) -> torch.Tensor:
return torch.empty(0, device=query.device, dtype=query.dtype)
direct_register_custom_op(
op_name="fused_rope_and_unified_kv_cache_update",
op_func=fused_rope_and_unified_kv_cache_update_impl,
mutates_args=["query", "key"],
fake_impl=fused_rope_and_unified_kv_cache_update_fake,
)
class RopeReshapeKVCachePattern:
"""
This pattern matches the following unfused inplace ops:
q, k = rotary_embedding(positions, q, k, head_size, cos_sin_cache, is_neox)
kv_cache_dummy = unified_kv_cache_update(k, v, layer_name)
and replaces it with the fused inplace op:
kv_cache_dummy = fused_rope_and_unified_kv_cache_update(
q, k, v, positions, cos_sin_cache, is_neox, layer_name
)
"""
FUSED_OP = torch.ops.vllm.fused_rope_and_unified_kv_cache_update.default
def __init__(
self,
layer: Attention,
is_neox: bool,
) -> None:
self.layer_name = layer.layer_name
self.num_heads = layer.num_heads
self.num_kv_heads = layer.num_kv_heads
self.head_size = layer.head_size
self.head_size_v = layer.head_size_v
self.is_neox = is_neox
self.q_size = self.num_heads * self.head_size
self.k_size = self.num_kv_heads * self.head_size
self.v_size = self.num_kv_heads * self.head_size_v
self.rope_matcher = MatcherRotaryEmbedding(
is_neox=self.is_neox,
head_size=self.head_size,
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
)
def get_inputs(self) -> list[torch.Tensor]:
# Sample inputs to help pattern tracing
T = 5
L = 4096
qkv = empty_bf16(T, self.q_size + self.k_size + self.v_size)
positions = empty_i64(T)
cos_sin_cache = empty_bf16(L, self.head_size)
return [
qkv,
positions,
cos_sin_cache,
]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
qkv: torch.Tensor,
positions: torch.Tensor,
cos_sin_cache: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
q, k, v = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1)
q, k = self.rope_matcher(positions, q, k, cos_sin_cache)
q = q.view(-1, self.num_heads, self.head_size)
k = k.view(-1, self.num_kv_heads, self.head_size)
v = v.view(-1, self.num_kv_heads, self.head_size_v)
dummy = torch.ops.vllm.unified_kv_cache_update(k, v, self.layer_name)
return dummy, q, k, v
def replacement(
qkv: torch.Tensor,
positions: torch.Tensor,
cos_sin_cache: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
q, k, v = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1)
q = q.view(-1, self.num_heads, self.head_size)
k = k.view(-1, self.num_kv_heads, self.head_size)
v = v.view(-1, self.num_kv_heads, self.head_size_v)
results = auto_functionalized(
self.FUSED_OP,
query=q,
key=k,
value=v,
positions=positions,
cos_sin_cache=cos_sin_cache,
is_neox=self.is_neox,
layer_name=self.layer_name,
)
return results[0], results[1], results[2], v
# NOTE: use view_to_reshape to unify view/reshape to simplify
# pattern and increase matching opportunities
def fwd_and_view_to_reshape(*args, **kwargs) -> fx.GraphModule:
gm = pm.fwd_only(*args, **kwargs)
view_to_reshape(gm)
return gm
pm.register_replacement(
pattern, replacement, self.get_inputs(), fwd_and_view_to_reshape, pm_pass
)
class RopeKVCacheFusionPass(VllmPatternMatcherPass):
"""
This pass fuses the rotary embedding and KV cache update operations
into a single fused kernel if available.
It uses the pattern matcher and matches each layer manually, as strings
cannot be wildcarded. This also lets us check support on attention layers
upon registration instead of during pattern matching.
This fusion eliminates the need for separate kernel launches and
intermediate memory operations between the RoPE and cache update steps.
"""
@enable_fake_mode
def __init__(self, config: VllmConfig) -> None:
super().__init__(config)
self.patterns: PatternMatcherPass = PatternMatcherPass(
pass_name="rope_kv_cache_fusion_pass"
)
cc = config.compilation_config
self.max_token_num = cc.pass_config.rope_kvcache_fusion_max_token_num
attn_layers = get_layers_from_vllm_config(config, Attention)
for _, layer in attn_layers.items():
if layer.impl.fused_rope_kvcache_supported():
for is_neox in [True, False]:
RopeReshapeKVCachePattern(
layer=layer,
is_neox=is_neox,
).register(self.patterns)
self.dump_patterns(config, self.patterns)
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph) -> None:
self.matched_count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", self.matched_count)
def is_applicable_for_range(self, compile_range: Range) -> bool:
# This pass works best for the small-batch decode setting.
# For large-batch e.g. prefill, it is better to use two separate kernels
# since they are compute bound and the fused kernels require further tuning.
return compile_range.end <= self.max_token_num
def uuid(self) -> str:
return VllmInductorPass.hash_source(self, RopeReshapeKVCachePattern)

View File

@@ -0,0 +1,452 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from collections.abc import Callable, Sequence
from typing import Any
import torch
import torch._inductor.pattern_matcher as pm
import torch.fx as fx
from torch._inductor.pattern_matcher import PatternMatcherPass
from vllm.config import VllmConfig
from vllm.config.utils import Range
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8StaticTensorSym,
)
from vllm.platforms import current_platform
from ..inductor_pass import enable_fake_mode
from ..utility.noop_elimination import NoOpEliminationPass
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm
logger = init_logger(__name__)
# Min hidden size per device capability for sequence parallelism
# Only apply sequence parallelism for models with hidden_size >= threshold
SP_MIN_HIDDEN_SIZE: dict[int, int] = {
90: 8192, # H100: only for models with hidden_size >= 8192
}
# Min size per GPU per device capability for sequence parallelism
# Total min size = min_per_gpu_size * tp_size
# This ensures the threshold scales appropriately with tensor parallelism
SP_MIN_PER_GPU_SIZE_MB: dict[int, float] = {
90: 8, # 8MB per GPU for H100
}
def get_sequence_parallelism_threshold(
hidden_size: int,
tp_size: int,
element_size: int,
) -> int | None:
"""
Calculate the minimum token threshold for applying sequence parallelism.
Returns None if sequence parallelism should not be applied based on model size.
Branching logic based on device capability:
- Check if hidden_size >= SP_MIN_HIDDEN_SIZE[device_capability]
- If not, returns None (SP disabled for small models on this device)
- If yes, calculates threshold based on per-GPU size
Formula: min_token_num = (min_per_gpu_size_mb * tp_size * MiB) //
(hidden_size * element_size)
"""
from vllm.platforms import current_platform
if not current_platform.is_cuda():
return None
capability = current_platform.get_device_capability()
if capability is None:
return None
device_capability = capability.to_int()
# Check if device has configured thresholds
min_hidden_size = SP_MIN_HIDDEN_SIZE.get(device_capability)
min_per_gpu_size_mb = SP_MIN_PER_GPU_SIZE_MB.get(device_capability)
if min_hidden_size is None or min_per_gpu_size_mb is None:
return None
# Only apply sequence parallelism for models meeting the size threshold
if hidden_size < min_hidden_size:
return None
MiB = 1024 * 1024
min_size = min_per_gpu_size_mb * MiB * tp_size
return int(min_size // (hidden_size * element_size))
def get_first_out_wrapper(
fn: Callable[..., Sequence[torch.Tensor]],
) -> Callable[..., torch.Tensor]:
@functools.wraps(fn)
def wrapper(*args: Any) -> torch.Tensor:
return fn(*args)[0]
return wrapper
class _SequenceParallelPatternHelper:
"""Helper for sequence parallelism patterns."""
def __init__(
self,
epsilon: float,
dtype: torch.dtype,
device: str | None,
) -> None:
self.epsilon = epsilon
self.dtype = dtype
self.device = device
self.tp_group = get_tp_group()
self.tp_size = get_tensor_model_parallel_world_size()
def _all_reduce(self, x: torch.Tensor) -> torch.Tensor:
return tensor_model_parallel_all_reduce(x)
def _reduce_scatter(self, x: torch.Tensor) -> torch.Tensor:
return torch.ops.vllm.reduce_scatter.default(
x, dim=0, world_size=self.tp_size, group_name=self.tp_group.unique_name
)
def _all_gather(self, x: torch.Tensor) -> torch.Tensor:
return torch.ops.vllm.all_gather.default(
x, dim=0, world_size=self.tp_size, group_name=self.tp_group.unique_name
)
class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
def __init__(self, epsilon: float, dtype: torch.dtype, device: str | None) -> None:
super().__init__(epsilon, dtype, device)
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
def get_inputs(self) -> list[torch.Tensor]:
input = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype)
arg3_1 = torch.empty([4], device=self.device, dtype=self.dtype)
return [input, arg3_1]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor,
arg3_1: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
all_reduce = self._all_reduce(input)
rmsnorm = self.rmsnorm_matcher(all_reduce, arg3_1)
return rmsnorm, all_reduce
def replacement(
input: torch.Tensor,
arg3_1: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
reduce_scatter = self._reduce_scatter(input)
rmsnorm = self.rmsnorm_matcher(reduce_scatter, arg3_1)
all_gather = self._all_gather(rmsnorm)
return all_gather, reduce_scatter
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
)
class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
def __init__(self, epsilon: float, dtype: torch.dtype, device: str | None) -> None:
super().__init__(epsilon, dtype, device)
self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
def get_inputs(self) -> list[torch.Tensor]:
mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype)
return [
residual,
mm_1,
rms_norm_weights,
]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
residual: torch.Tensor,
mm_1: torch.Tensor,
rms_norm_weights: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
all_reduce = self._all_reduce(mm_1)
rmsnorm = self.rmsnorm_matcher(all_reduce, rms_norm_weights, residual)
return rmsnorm[0], rmsnorm[1]
def replacement(
residual: torch.Tensor,
mm_1: torch.Tensor,
rms_norm_weights: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
# pattern matcher replaces from top-to-bottom,
# so residual is still the full size here.
# once the seqpar pattern with the previous rmsnorm is replaced
reduce_scatter = self._reduce_scatter(mm_1)
residual = residual[0 : reduce_scatter.size(0), ...]
rmsnorm = self.rmsnorm_matcher(reduce_scatter, rms_norm_weights, residual)
all_gather = self._all_gather(rmsnorm[0])
# shape of residual changes but that's fine,
# next node is already slicing it, now becomes a noop
return all_gather, rmsnorm[1]
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
)
pm.register_replacement(
get_first_out_wrapper(pattern),
get_first_out_wrapper(replacement),
self.get_inputs(),
pm.fwd_only,
pm_pass,
)
FP8_DTYPE = current_platform.fp8_dtype()
class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
def __init__(
self,
epsilon: float,
dtype: torch.dtype,
device: str | None,
) -> None:
super().__init__(epsilon, dtype, device)
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
def get_inputs(self) -> list[torch.Tensor]:
input = torch.zeros([1, 8, 4], device=self.device, dtype=self.dtype)
weight = torch.empty([4], device=self.device, dtype=self.dtype)
scale = torch.tensor(1.0, device=self.device, dtype=torch.float32)
return [input, weight, scale]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
all_reduce = self._all_reduce(input)
rms = self.rmsnorm_matcher(all_reduce, weight)
quant, _ = self.quant_matcher(rms, scale)
return quant, all_reduce
def replacement(
input: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
reduce_scatter = self._reduce_scatter(input)
rms = self.rmsnorm_matcher(reduce_scatter, weight)
quant, _ = self.quant_matcher(rms, scale)
all_gather = self._all_gather(quant)
return all_gather, reduce_scatter
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
)
class MiddleAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
def __init__(self, epsilon: float, dtype: torch.dtype, device: str | None) -> None:
super().__init__(epsilon, dtype, device)
self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
def get_inputs(self) -> list[torch.Tensor]:
mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype)
scale = torch.empty([1, 1], device=self.device, dtype=torch.float32)
return [residual, mm_1, rms_norm_weights, scale]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
residual: torch.Tensor,
mm_1: torch.Tensor,
rms_norm_weights: torch.Tensor,
scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
all_reduce = self._all_reduce(mm_1)
rms, residual_out = self.rmsnorm_matcher(
all_reduce, rms_norm_weights, residual
)
quant, _ = self.quant_matcher(rms, scale)
return quant, residual_out
def replacement(
residual: torch.Tensor,
mm_1: torch.Tensor,
rms_norm_weights: torch.Tensor,
scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
# pattern matcher replaces from top-to-bottom,
# so residual is still the full size here.
# add a temporary slice which will become a noop
# once the seqpar pattern with the previous rmsnorm is replaced
reduce_scatter = self._reduce_scatter(mm_1)
residual = residual[0 : reduce_scatter.size(0), ...]
rms, residual_out = self.rmsnorm_matcher(
reduce_scatter, rms_norm_weights, residual
)
quant, _ = self.quant_matcher(rms, scale)
all_gather = self._all_gather(quant)
# shape of residual changes but that's fine,
# next node is already slicing it, now becomes a noop
return all_gather, residual_out
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
)
pm.register_replacement(
get_first_out_wrapper(pattern),
get_first_out_wrapper(replacement),
self.get_inputs(),
pm.fwd_only,
pm_pass,
)
class SequenceParallelismPass(VllmPatternMatcherPass):
"""
This pass enables sequence parallelism for models.
It identifies patterns where an AllReduce operation is followed by
an RMSNorm (or RMSNorm and then Quantization) operation.
These patterns are replaced with a ReduceScatter operation, followed by
a local RMSNorm/Quantization, and then an AllGather operation.
The general transformation is:
Input -> AllReduce -> RMSNorm -> Output
becomes
Input -> ReduceScatter -> RMSNorm -> AllGather -> Output
While this pass itself does not directly yield performance improvements,
it lays the groundwork for subsequent fusion passes, such as
GEMM + ReduceScatter and AllGather + GEMM fusions. These fusions can
significantly reduce communication overhead and improve overall model
performance.
This pass splits up the residual tensor across TP ranks and hence divides its size.
Because the pattern matcher starts at the end of the graph, the replacement
contains a slice that temporarily conforms the input residual to the correct size.
After all patterns have been matched, we use a NoOpEliminationPass to clean up
what have now become no-op slices.
Note that an older version of the pass did not need this as it operated only on
custom rms_norm and fused_rms_norm_add custom ops which did not complain about
mismatched shapes during replacement. So this approach has the same assumption that
correctness is only maintained if all rms_norm operations are split across ranks.
Correctness-wise, this is approach strictly better than before - before,
the graph was incorrect semantically and shape-wise during the pass.
With this approach there's only semantic incorrectness during the pass.
Both approaches restore a correct graph once all patterns are matched.
"""
@enable_fake_mode
def __init__(self, config: VllmConfig) -> None:
super().__init__(config)
# Get min_token_num threshold
# Read min_token_num from config (calculated during config init)
self.min_token_num = None
if config.model_config is not None:
pass_config = config.compilation_config.pass_config
self.min_token_num = pass_config.sp_min_token_num
if self.min_token_num is not None:
# Take the min to avoid exceeding max_num_batched_tokens
max_batched = config.scheduler_config.max_num_batched_tokens
if max_batched is not None:
self.min_token_num = min(self.min_token_num, max_batched)
logger.debug_once(
f"Sequence parallelism min token threshold: {self.min_token_num}",
scope="global",
)
# Used to clean up redundant views created temporarily
# to circumvent residual shape change issues
self.noop_cleanup = NoOpEliminationPass(config)
self.noop_cleanup.pass_name = f"{self.pass_name}.{self.noop_cleanup.pass_name}"
self.patterns: PatternMatcherPass = PatternMatcherPass(
pass_name="sequence_parallelism_pass"
)
for epsilon in [1e-5, 1e-6]:
# RMSNorm + Static FP8 quantization patterns
FirstAllReduceRMSNormStaticFP8Pattern(
epsilon, self.model_dtype, self.device
).register(self.patterns)
MiddleAllReduceRMSNormStaticFP8Pattern(
epsilon, self.model_dtype, self.device
).register(self.patterns)
# Normal RMSNorm patterns
FirstAllReduceRMSNormPattern(
epsilon, self.model_dtype, self.device
).register(self.patterns)
MiddleAllReduceRMSNormPattern(
epsilon, self.model_dtype, self.device
).register(self.patterns)
self.dump_patterns(config, self.patterns)
def is_applicable_for_range(self, compile_range: Range) -> bool:
"""
Determines if sequence parallelism should be applied for the given
compile range.
SP is only beneficial for larger batch sizes where the communication
overhead is amortized. For small batches, the overhead of splitting
and gathering tensors across TP ranks outweighs the benefits.
Returns False (SP disabled) when:
- Using piecewise compilation with non-concrete or TP-indivisible sizes
- min_token_num is None (SP disabled for this device/config)
- The compile range starts below the minimum token threshold
"""
# For piecewise compilation (not using inductor graph partition),
# we need concrete sizes that are divisible by TP for correct splitting
if (
not self.compilation_config.use_inductor_graph_partition
and self.compilation_config.splitting_ops
):
tp_size = get_tensor_model_parallel_world_size()
if not compile_range.is_single_size() or compile_range.end % tp_size != 0:
return False
# min_token_num is None when SP is disabled for this device/config
# (e.g., non-CUDA platform, unsupported GPU, or small hidden_size)
if self.min_token_num is None:
return False
# Only apply SP when batch size meets the minimum threshold
return compile_range.start >= self.min_token_num
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph) -> None:
self.matched_count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", self.matched_count)
# Clean up reshape nodes
self.noop_cleanup(graph)

View File

@@ -0,0 +1,77 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import operator
from collections.abc import Iterable, Iterator
from torch import fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._ops import OpOverload, OpOverloadPacket
from torch.fx.node import Target
def is_func(node: fx.Node, target: Target) -> bool:
return bool(node.op == "call_function" and node.target == target)
def is_auto_func(node: fx.Node, op: OpOverload) -> bool:
return is_func(node, auto_functionalized) and node.args[0] == op
# Returns the first auto_functionalized node with the given op (if it exists)
def find_auto_fn_maybe(nodes: Iterable[fx.Node], op: OpOverload) -> fx.Node | None:
for node in nodes:
if is_func(node, auto_functionalized) and node.args[0] == op: # noqa
return node
return None
# Returns the first auto_functionalized node with the given op
def find_auto_fn(nodes: Iterable[fx.Node], op: OpOverload) -> fx.Node:
node = find_auto_fn_maybe(nodes, op)
assert node is not None, f"Could not find {op} in nodes {nodes}"
return node
# Returns the getitem node that extracts the idx-th element from node
# (if it exists)
def find_getitem_maybe(node: fx.Node, idx: int) -> fx.Node | None:
for user in node.users:
if is_func(user, operator.getitem) and user.args[1] == idx:
return user
return None
# Returns the getitem node that extracts the idx-th element from node
def find_getitem(node: fx.Node, idx: int) -> fx.Node:
ret = find_getitem_maybe(node, idx)
assert ret is not None, f"Could not find getitem {idx} in node {node}"
return ret
# An auto-functionalization-aware utility for finding nodes with a specific op
# Also handles op overload packets and finds all overloads
def find_op_nodes(
op: OpOverload | OpOverloadPacket, graph: fx.Graph
) -> Iterator[fx.Node]:
if isinstance(op, OpOverloadPacket):
for overload in op.overloads():
overload_op = getattr(op, overload)
yield from find_op_nodes(overload_op, graph)
return
assert isinstance(op, OpOverload)
yield from graph.find_nodes(op="call_function", target=op)
for n in graph.find_nodes(op="call_function", target=auto_functionalized):
if n.args[0] == op:
yield n
# Asserts that the node only has one user and returns it
# Even if a node has only 1 user, it might share storage with another node,
# which might need to be taken into account.
def get_only_user(node: fx.Node) -> fx.Node:
assert len(node.users) == 1
return next(iter(node.users))

View File

@@ -0,0 +1,134 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import functools
import hashlib
import inspect
import json
import types
from collections.abc import Callable, Generator
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar
import torch
from torch import fx
from torch._subclasses.fake_tensor import FakeTensorMode, unset_fake_temporarily
if TYPE_CHECKING:
from vllm.config.utils import Range
from torch._inductor.custom_graph_pass import CustomGraphPass
_pass_context = None
P = ParamSpec("P")
R = TypeVar("R")
class PassContext:
def __init__(self, compile_range: Range):
self.compile_range: Range = compile_range
def get_pass_context() -> PassContext:
"""Get the current pass context."""
assert _pass_context is not None
return _pass_context
@contextmanager
def pass_context(compile_range: Range) -> Generator[None, None, None]:
"""A context manager that stores the current pass context,
usually it is a list of sizes to specialize.
"""
global _pass_context
prev_context = _pass_context
_pass_context = PassContext(compile_range)
try:
yield
finally:
_pass_context = prev_context
class InductorPass(CustomGraphPass): # type: ignore[misc]
"""
A custom graph pass that uses a hash of its source as the UUID.
This is defined as a convenience and should work in most cases.
"""
def uuid(self) -> str:
"""
Provide a unique identifier for the pass, used in Inductor code cache.
This should depend on the pass implementation, so that changes to the
pass result in recompilation.
By default, the object source is hashed.
"""
return InductorPass.hash_source(self)
@staticmethod
def hash_source(*srcs: str | Any) -> str:
"""
Utility method to hash the sources of functions or objects.
:param srcs: strings or objects to add to the hash.
Objects and functions have their source inspected.
:return:
"""
hasher = hashlib.sha256()
for src in srcs:
if isinstance(src, str):
src_str = src
elif isinstance(src, (types.FunctionType, type)):
src_str = inspect.getsource(src)
else:
# object instance
src_str = inspect.getsource(src.__class__)
hasher.update(src_str.encode("utf-8"))
return hasher.hexdigest()
@staticmethod
def hash_dict(dict_: dict[Any, Any]) -> str:
"""
Utility method to hash a dictionary, can alternatively be used for uuid.
:return: A sha256 hash of the json rep of the dictionary.
"""
encoded = json.dumps(dict_, sort_keys=True).encode("utf-8")
return hashlib.sha256(encoded).hexdigest()
def is_applicable_for_range(self, compile_range: Range) -> bool:
return True
class CallableInductorPass(InductorPass):
"""
This class is a wrapper for a callable that automatically provides an
implementation of the UUID.
"""
def __init__(
self, callable: Callable[[fx.Graph], None], uuid: Any | None = None
) -> None:
self.callable = callable
self._uuid = self.hash_source(callable) if uuid is None else uuid
def __call__(self, graph: torch.fx.Graph) -> None:
self.callable(graph)
def uuid(self) -> Any:
return self._uuid
def enable_fake_mode(fn: Callable[P, R]) -> Callable[P, R]:
"""
Applies a FakeTensorMode context. This is useful when you don't want to
create or run things with real tensors.
"""
@functools.wraps(fn)
def fn_new(*args: P.args, **kwargs: P.kwargs) -> R:
with torch._guards.tracing(None), unset_fake_temporarily(), FakeTensorMode():
result = fn(*args, **kwargs)
return result
return fn_new

View File

@@ -0,0 +1,178 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from collections.abc import Callable
from typing import Any, ParamSpec, TypeVar
from torch import fx as fx
from vllm import envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.compilation.passes.utility.post_cleanup import PostCleanupPass
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.system_utils import set_env_var
from .vllm_inductor_pass import VllmInductorPass
if rocm_aiter_ops.is_enabled():
from .fusion.rocm_aiter_fusion import (
RocmAiterRMSNormQuantFusionPass,
RocmAiterSiluMulFp8GroupQuantFusionPass,
RocmAiterTritonAddRMSNormPadFusionPass,
)
if current_platform.is_cuda_alike():
from .fusion.act_quant_fusion import ActivationQuantFusionPass
from .fusion.attn_quant_fusion import AttnFusionPass
from .fusion.qk_norm_rope_fusion import QKNormRoPEFusionPass
from .fusion.rms_quant_fusion import RMSNormQuantFusionPass
from .fusion.rope_kvcache_fusion import RopeKVCacheFusionPass
from .fusion.sequence_parallelism import SequenceParallelismPass
from .utility.scatter_split_replace import ScatterSplitReplacementPass
from .utility.split_coalescing import SplitCoalescingPass
if current_platform.is_cuda():
from .fusion.allreduce_rms_fusion import AllReduceFusionPass
from .fusion.collective_fusion import AsyncTPPass
from .inductor_pass import (
CustomGraphPass,
InductorPass,
get_pass_context,
)
from .utility.fix_functionalization import FixFunctionalizationPass
from .utility.noop_elimination import NoOpEliminationPass
logger = init_logger(__name__)
P = ParamSpec("P")
R = TypeVar("R")
def with_pattern_match_debug(fn: Callable[P, R]) -> Callable[P, R]:
"""
Function decorator that turns on inductor pattern match debug
for the duration of the call.
Used to avoid logging builtin Inductor pattern matching.
"""
@functools.wraps(fn)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
if (debug_val := envs.VLLM_PATTERN_MATCH_DEBUG) is not None:
# optionally check rank here
with set_env_var("TORCHINDUCTOR_PATTERN_MATCH_DEBUG", debug_val):
return fn(*args, **kwargs)
return fn(*args, **kwargs)
return wrapper
class PostGradPassManager(CustomGraphPass): # type: ignore[misc]
"""
The pass manager for post-grad passes.
It handles configuration, adding custom passes, and running passes.
It supports uuid for the Inductor code cache. That includes torch<2.6
support using pickling (in .inductor_pass.CustomGraphPass).
The order of the post-grad post-passes is:
1. passes (constructor parameter)
2. default passes (NoopEliminationPass, FusionPass)
3. config["post_grad_custom_post_pass"] (if it exists)
4. fix_functionalization
This way, all passes operate on a functionalized graph.
"""
def __init__(self) -> None:
self.passes: list[InductorPass] = []
@with_pattern_match_debug
def __call__(self, graph: fx.Graph) -> None:
VllmInductorPass.dump_prefix = 0 # reset dump index
compile_range = get_pass_context().compile_range
for pass_ in self.passes:
if pass_.is_applicable_for_range(compile_range):
pass_(graph)
VllmInductorPass.dump_prefix += 1
else:
logger.debug("Skipping %s with compile range %s", pass_, compile_range)
# post-cleanup goes before fix_functionalization
# because it requires a functional graph
self.post_cleanup(graph)
VllmInductorPass.dump_prefix += 1
# always run fix_functionalization last
self.fix_functionalization(graph)
VllmInductorPass.dump_prefix = None # Cleanup index
def configure(self, config: VllmConfig) -> None:
self.pass_config = config.compilation_config.pass_config
# Set the current vllm config to allow tracing CustomOp instances
with set_current_vllm_config(config, check_compile=False):
if self.pass_config.eliminate_noops:
self.passes += [NoOpEliminationPass(config)]
if self.pass_config.enable_sp:
self.passes += [SequenceParallelismPass(config)]
if self.pass_config.fuse_gemm_comms:
self.passes += [AsyncTPPass(config)]
if self.pass_config.fuse_allreduce_rms:
self.passes += [AllReduceFusionPass(config)]
if self.pass_config.fuse_norm_quant:
self.passes += [RMSNormQuantFusionPass(config)]
if rocm_aiter_ops.is_enabled():
self.passes += [
RocmAiterRMSNormQuantFusionPass(config),
]
if self.pass_config.fuse_act_quant:
self.passes += [ActivationQuantFusionPass(config)]
if rocm_aiter_ops.is_enabled():
self.passes += [RocmAiterSiluMulFp8GroupQuantFusionPass(config)]
if self.pass_config.fuse_act_padding and rocm_aiter_ops.is_enabled():
self.passes += [RocmAiterTritonAddRMSNormPadFusionPass(config)]
if self.pass_config.fuse_rope_kvcache:
self.passes += [SplitCoalescingPass(config)]
self.passes += [ScatterSplitReplacementPass(config)]
self.passes += [RopeKVCacheFusionPass(config)]
if self.pass_config.fuse_attn_quant:
self.passes += [AttnFusionPass(config)]
if self.pass_config.enable_qk_norm_rope_fusion:
self.passes += [SplitCoalescingPass(config)]
self.passes += [QKNormRoPEFusionPass(config)]
# needs a functional graph
self.post_cleanup = PostCleanupPass(config)
self.fix_functionalization = FixFunctionalizationPass(config)
def add(self, pass_: InductorPass) -> None:
assert isinstance(pass_, InductorPass)
self.passes.append(pass_)
def uuid(self) -> str:
"""
The PostGradPassManager is set as a custom pass in the Inductor and
affects compilation caching. Its uuid depends on the UUIDs of all
dependent passes and the pass config. See InductorPass for more info.
"""
passes = []
state: dict[str, Any] = {"pass_config": self.pass_config.compute_hash()}
for pass_ in self.passes:
passes.append(pass_.uuid())
passes.append(self.fix_functionalization.uuid())
# Include the compile range in the uuid to ensure that inductor
# recompiles the graph for the new dynamic compile range.
state["compile_range"] = str(get_pass_context().compile_range)
state["passes"] = passes
return InductorPass.hash_dict(state)

View File

@@ -0,0 +1,301 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import operator
from collections.abc import Iterable
import torch
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from vllm.logger import init_logger
from vllm.platforms import current_platform
from ..fx_utils import is_func
from ..vllm_inductor_pass import VllmInductorPass
logger = init_logger(__name__)
class FixFunctionalizationPass(VllmInductorPass):
"""
This pass defunctionalizes certain nodes to avoid redundant tensor copies.
After this pass, DCE (dead-code elimination) should never be run,
as de-functionalized nodes may appear as dead code.
To add new nodes to defunctionalize, add to the if-elif chain in __call__.
"""
@VllmInductorPass.time_and_log
def __call__(self, graph: torch.fx.Graph) -> None:
# XPU does not support auto-functionalization yet.
# Will enable this when switch to vllm-xpu-kernels.
if current_platform.is_xpu():
logger.debug(
"XPU platform does not support fix functionalizationpass currently."
)
return
self.nodes_to_remove: list[torch.fx.Node] = []
count = 0
for node in graph.nodes:
if not is_func(node, auto_functionalized):
continue # Avoid deep if-elif nesting
kwargs = node.kwargs
at_target = node.args[0]
if at_target == torch.ops._C.rotary_embedding.default:
query = kwargs["query"]
key = kwargs["key"]
getitem_nodes = self.getitem_users(node)
if (
is_func(query, operator.getitem)
and is_func(key, operator.getitem)
and query.args[0] == key.args[0]
and is_func(query.args[0], torch.ops.aten.split_with_sizes.default)
and all(
is_func(user, torch.ops.aten.slice_scatter.default)
for getitem_node in getitem_nodes.values()
for user in getitem_node.users
)
):
# Pattern where query and key are slices of an mm_node.
# While functionalized, results at [1] and [2] are scattered
# back into mm_node. So after de-functionalization, we can
# just use mm_node directly.
mm_node = query.args[0].args[0]
for user in getitem_nodes.values():
for user_of_getitem in user.users:
if is_func(
user_of_getitem, torch.ops.aten.slice_scatter.default
):
user_of_getitem.replace_all_uses_with(mm_node)
self._remove(user_of_getitem)
self._remove(user)
self.insert_defunctionalized(graph, node)
self._remove(node)
else:
# Directly replace the auto_functionalize(rotary_embedding)
# with the inplace rotary_embedding. In theory, we shouldn't
# do this blindly, but in practice in vLLM it's ok. The best
# solution is to use auto_functionalization_v2 and then use
# inductor's builtin defunctionalization (reinplacing) pass.
mutated_args = {1: "query", 2: "key"}
self.defunctionalize(graph, node, mutated_args)
# rms_norm replacements avoid the most copies for LLaMa.
elif at_target == torch.ops._C.fused_add_rms_norm.default:
mutated_args = {1: "input", 2: "residual"}
self.defunctionalize(graph, node, mutated_args)
elif at_target == torch.ops._C.fused_add_rms_norm_static_fp8_quant.default: # noqa: E501
mutated_args = {1: "result", 2: "residual"}
self.defunctionalize(graph, node, mutated_args)
elif at_target == torch.ops._C.rms_norm_dynamic_per_token_quant.default: # noqa: E501
mutated_args = {1: "result", 2: "scale", 3: "residual"}
self.defunctionalize(graph, node, mutated_args)
elif at_target in [
torch.ops._C.rms_norm.default,
torch.ops._C.rms_norm_static_fp8_quant.default,
]:
mutated_args = {1: "result"}
self.defunctionalize(graph, node, mutated_args)
elif (
hasattr(torch.ops.vllm, "flashinfer_trtllm_fused_allreduce_norm")
and at_target
== torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default
):
mutated_args = {
1: "allreduce_in",
2: "residual",
3: "norm_out",
4: "quant_out",
5: "scale_out",
}
self.defunctionalize(graph, node, mutated_args)
# For some reason we need to specify the args for both
# silu_and_mul and silu_and_mul_quant. The kwargs
# pathway gets the wrong answer.
elif at_target == torch.ops._C.silu_and_mul.default:
mutated_args = {1: "result"}
self.defunctionalize(
graph, node, mutated_args, args=("result", "input")
)
elif at_target == torch.ops._C.silu_and_mul_quant.default:
mutated_args = {1: "result"}
self.defunctionalize(
graph, node, mutated_args, args=("result", "input", "scale")
)
elif (
hasattr(torch.ops._C, "silu_and_mul_nvfp4_quant")
and at_target == torch.ops._C.silu_and_mul_nvfp4_quant.default
):
mutated_args = {1: "result", 2: "result_block_scale"}
self.defunctionalize(
graph,
node,
mutated_args,
args=(
"result",
"result_block_scale",
"input",
"input_global_scale",
),
)
# Defunctionalize fused_qk_norm_rope to remove higher-order wrapper.
elif at_target == torch.ops._C.fused_qk_norm_rope.default:
mutated_args = {1: "qkv"}
args = (
"qkv",
"num_heads_q",
"num_heads_k",
"num_heads_v",
"head_dim",
"eps",
"q_weight",
"k_weight",
"cos_sin_cache",
"is_neox",
"position_ids",
)
self.defunctionalize(graph, node, mutated_args=mutated_args, args=args)
elif (
hasattr(torch.ops.vllm, "fused_rope_and_unified_kv_cache_update")
and at_target
== torch.ops.vllm.fused_rope_and_unified_kv_cache_update.default
):
mutated_args = {
1: "query",
2: "key",
}
self.defunctionalize(graph, node, mutated_args=mutated_args)
# only used for test_functionalization::TestFunctionWithMutatedArgsAndReturn
elif (
hasattr(torch.ops.vllm, "function_with_mutated_args_and_return")
and at_target
== torch.ops.vllm.function_with_mutated_args_and_return.default
):
mutated_args = {1: "x"}
self.defunctionalize(graph, node, mutated_args=mutated_args)
else:
continue # skip the count
count += 1
self.dump_graph(graph, "before_cleanup")
# Remove the nodes all at once
count_removed = len(self.nodes_to_remove)
for node in self.nodes_to_remove:
graph.erase_node(node)
logger.debug(
"De-functionalized %s nodes, removed %s nodes", count, count_removed
)
self.nodes_to_remove.clear()
def _remove(self, node_or_nodes: torch.fx.Node | Iterable[torch.fx.Node]) -> None:
"""
Stage a node (or nodes) for removal at the end of the pass.
"""
if isinstance(node_or_nodes, torch.fx.Node):
self.nodes_to_remove.append(node_or_nodes)
else:
self.nodes_to_remove.extend(node_or_nodes)
def defunctionalize(
self,
graph: torch.fx.Graph,
node: torch.fx.Node,
mutated_args: dict[int, torch.fx.Node | str],
args: tuple[torch.fx.Node | str, ...] | None = None,
) -> None:
"""
De-functionalize a node by replacing it with a call to the original.
It also replaces the getitem users with the mutated arguments.
See replace_users_with_mutated_args and insert_defunctionalized.
"""
self.replace_users_with_mutated_args(node, mutated_args)
self.insert_defunctionalized(graph, node, args=args)
self._remove(node)
def replace_users_with_mutated_args(
self, node: torch.fx.Node, mutated_args: dict[int, torch.fx.Node | str]
) -> None:
"""
Replace mutated getitem users of the auto-functionalized node with the
mutated arguments.
:param node: The auto-functionalized node
:param mutated_args: The mutated arguments, indexed by getitem index.
If the value of an arg is a string, `node.kwargs[arg]` is used.
"""
for idx, user in self.getitem_users(node).items():
# Some functionalized nodes may return both a result at getitem[0]
# as well as mutated args at getitem[1:...]
if idx == 0:
assert idx not in mutated_args, (
f"result at getitem[0] should not be in mutated_args for {node}"
)
continue
arg = mutated_args[idx]
arg = node.kwargs[arg] if isinstance(arg, str) else arg
user.replace_all_uses_with(arg)
self._remove(user)
def getitem_users(self, node: torch.fx.Node) -> dict[int, torch.fx.Node]:
"""
Returns the operator.getitem users of the auto-functionalized node,
indexed by the index they are getting.
"""
users = {}
for user in node.users:
if is_func(user, operator.getitem):
idx = user.args[1]
users[idx] = user
return users
def insert_defunctionalized(
self,
graph: torch.fx.Graph,
node: torch.fx.Node,
args: tuple[torch.fx.Node | str, ...] | None = None,
) -> None:
"""
Insert a new defunctionalized node into the graph before node.
If one of the kwargs is 'out', provide args directly,
as node.kwargs cannot be used.
See https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351
:param graph: Graph to insert the defunctionalized node into
:param node: The auto-functionalized node to defunctionalize
:param args: If we cannot use kwargs, specify args directly.
If an arg is a string, `node.kwargs[arg]` is used.
""" # noqa: E501
assert is_func(node, auto_functionalized), (
f"node must be auto-functionalized, is {node} instead"
)
# Create a new call to the original function
with graph.inserting_before(node):
function = node.args[0]
if args is None:
fn_node = graph.call_function(function, kwargs=node.kwargs)
else:
# Args passed as strings refer to items in node.kwargs
args = tuple(
node.kwargs[arg] if isinstance(arg, str) else arg for arg in args
)
fn_node = graph.call_function(function, args=args)
# If the function returns a value as well as mutating args inplace,
# the functionalized node will have a getitem[0] user that holds this value
# Replace getitem[0] user of the auto-functionalized node
# with the new defunctionalized node directly if it exists
users = self.getitem_users(node)
if 0 in users:
user = users[0]
user.replace_all_uses_with(fn_node)
self._remove(user)

View File

@@ -0,0 +1,130 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
import torch.fx
from torch import SymInt
from torch.fx.experimental.symbolic_shapes import statically_known_true
from vllm.logger import init_logger
from ..fx_utils import is_func
from ..vllm_inductor_pass import VllmInductorPass
logger = init_logger(__name__)
class NoOpEliminationPass(VllmInductorPass):
"""
This is an inductor pass that removes redundant reshape/slice operations.
It is required for RMSNorm-quant fusion to work properly.
That's because apply_fp8_linear adds a reshape, which is redundant
in the 2D-case. Additionally, torch internal no-op elimination pass does
not handle certain slice variants.
Cases handled:
1. A chain of reshapes is equivalent to the last reshape called on the
base tensor (input of the first reshape).
2. A reshape that produces the shape of the input is redundant
3. A slice that produces the shape of the input is redundant
Example graph 1:
mul_1: "f16[s0, 4096]" = ...
view_1: "f16[s0, 128, 32]" = torch.reshape(mul_1, [-1, 128, 32])
view_2: "f16[s0, 4096]" = torch.reshape(view_2, [-1, 4096])
view_3: "f16[s0, 128, 32]" = torch.reshape(view_3, [-1, 128, 32])
Can be replaced with:
mul_1: "f16[s0, 4096]" = ...
view_3: "f16[s0, 128, 32]" = ...
Example graph 2:
getitem_1: "f16[s0, 4096]" = ...
view_1: "f16[s0, 4096]" = torch.reshape(getitem_1, [-1, 4096])
at = auto_functionalized(static_scaled_fp8_quant, input = view_1, ...)
out: "f8e4m3fn[s0, 4096]" = at[1]
Can be replaced with:
getitem_1: "f16[s0, 4096]" = ...
at = auto_functionalized(static_scaled_fp8_quant, input = getitem_1, ...)
out: "f8e4m3fn[s0, 4096]" = at[1]
Example graph 3:
arg0: "s0" = SymInt(s0)
scaled_mm: "f16[s0, 4096]" = ...
slice_1: "f16[s0, 4096]" = torch.slice(scaled_mm, -1, 0, arg0)
at = auto_functionalized(fused_add_rms_norm, input = slice_1, ...)
out: "f16[s0, 4096]" = torch.slice_scatter(scaled_mm, at[1], 0, 0, arg0)
Can be replaced with:
arg0: "s0" = SymInt(s0)
scaled_mm: "f16[s0, 4096]" = ...
at = auto_functionalized(fused_add_rms_norm, input = scaled_mm, ...)
out: "f16[s0, 4096]" = at[1]
"""
@VllmInductorPass.time_and_log
def __call__(self, graph: torch.fx.Graph) -> None:
count = 0
# Remove no-op reshapes/views:
for node in graph.nodes:
if is_func(node, torch.ops.aten.reshape.default):
# Case 1: rewrite reshape chains to reshapes on the base tensor
input = node.args[0]
# If the input is a reshape, rebind to that node
if is_func(input, torch.ops.aten.reshape.default):
# The new input is guaranteed not to be a reshape,
# because we process nodes in order
node.update_arg(0, input.args[0])
if len(input.users) == 0:
graph.erase_node(input)
count += 1
# remove reshape/slice if it produces the original shape
if is_func(node, torch.ops.aten.reshape.default) or is_func(
node, torch.ops.aten.slice.Tensor
):
input = node.args[0]
input_shape = input.meta["val"].shape
output_shape = node.meta["val"].shape
if self.all_dims_equivalent(input_shape, output_shape):
node.replace_all_uses_with(input)
graph.erase_node(node)
count += 1
elif is_func(node, torch.ops.aten.slice_scatter.default):
base, view, dim_index, start, end = node.args[:5]
base_shape = base.meta["val"].shape
view_shape = view.meta["val"].shape
if self.all_dims_equivalent(base_shape, view_shape):
node.replace_all_uses_with(view)
graph.erase_node(node)
count += 1
logger.debug("Removed %s no-op reshapes and slices", count)
# ---------------------- Shape comparison helpers ----------------------
def dims_equivalent(self, dim: int | SymInt, i_dim: int | SymInt) -> bool:
"""
This function checks if two dimensions are equivalent.
:param dim: The dimension arg to reshape/slice
:param i_dim: The corresponding dimension in the input tensor
:return: Are the dimensions equivalent?
There are two cases in which the dimensions are equivalent:
1. The dimensions are equal (both integers)
2. The dimensions both correspond to the same SymInt
"""
# Case 1
return statically_known_true(dim == i_dim) # type: ignore[no-any-return]
def all_dims_equivalent(
self, dims: Iterable[int | SymInt], i_dims: Iterable[int | SymInt]
) -> bool:
dims_ = list(dims)
i_dims_ = list(i_dims)
if len(dims_) != len(i_dims_):
# Different ranks can't be equivalent
return False
return all(self.dims_equivalent(s, i_s) for s, i_s in zip(dims, i_dims))

View File

@@ -0,0 +1,21 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from torch import fx
from ..vllm_inductor_pass import VllmInductorPass
class PostCleanupPass(VllmInductorPass):
"""
This pass performs cleanup after custom passes.
It topologically sorts the graph and removes unused nodes.
This is needed because the pattern matcher does not guarantee producing
a topologically sorted graph, and there may be unused nodes left around.
"""
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph) -> None:
from torch._inductor.pattern_matcher import stable_topological_sort
stable_topological_sort(graph)
graph.eliminate_dead_code()

View File

@@ -0,0 +1,138 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Replace ``slice_scatter`` and ``split_with_sizes`` nodes with a single
assignment if there are no users for the inplace tensor written to by
the slice_scatter call.
The inplace rotary_embedding custom op takes in mutable query and key inputs
that are split+getitem outputs of a single qkv tensor.
When functionalized, we fetch the rotated query and key from the functionalized op
using `getitem` calls. However, we also write to the qkv tensor inplace using a
`slice_scatter`, then split the inplace tensor to get the output tensors again.
Instead, if the inplace tensor has no subsequent users, we can just replace the
`slice_scatter` and `split_with_sizes` nodes with the `getitem` calls.
This is already done in fix_functionalization::FixFunctionalizationPass, but
writing a custom pass for it before defunctionalization allows matching against the
qkv split+rotary_embedding subpattern as part of e.g. the RoPE+KVCache fusion pass.
"""
import operator
import torch
from torch import fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from vllm.logger import init_logger
from ..fx_utils import is_func
from ..vllm_inductor_pass import VllmInductorPass
logger = init_logger(__name__)
class ScatterSplitReplacementPass(VllmInductorPass):
"""Replace getitem+slice_scatter+split nodes with a single getitem when
the inplace subtensor written to by the slice_scatter has no other users.
Here's an example graph with q_size = 512, kv_size = 64:
split_with_sizes_1 = torch.ops.aten.split_with_sizes.default(qkv, (512, 64, 64), -1)
at = auto_functionalized(torch.ops._C.rotary_embedding.default(positions, q, k))
q = operator.getitem(at, 1)
k = operator.getitem(at, 2)
torch.ops.aten.slice_scatter.default(qkv, q, [0, 512], -1)
torch.ops.aten.slice_scatter.default(qkv, k, [512, 512 + 64], -1)
split_with_sizes_2 = torch.ops.aten.split_with_sizes.default(qkv, (512, 64, 64), -1)
q = operator.getitem(split_with_sizes_2, 0)
k = operator.getitem(split_with_sizes_2, 1)
v = operator.getitem(split_with_sizes_2, 2)
After this pass, this sequence of nodes is replaced with:
split_with_sizes_1 = torch.ops.aten.split_with_sizes.default(qkv, (512, 64, 64), -1)
at = auto_functionalized(torch.ops._C.rotary_embedding.default(positions, q, k))
q = operator.getitem(at, 1)
k = operator.getitem(at, 2)
v = operator.getitem(split_with_sizes_1, 2)
"""
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph) -> None:
count = 0
target_ops = [torch.ops._C.rotary_embedding.default]
if hasattr(torch.ops.vllm, "rocm_aiter_triton_rotary_embedding"):
target_ops.append(torch.ops.vllm.rocm_aiter_triton_rotary_embedding.default)
for node in graph.nodes:
if not is_func(node, auto_functionalized):
continue
kwargs = node.kwargs
at_target = node.args[0]
if at_target in target_ops:
query = kwargs["query"]
key = kwargs["key"]
getitem_nodes = {}
for user in node.users:
if is_func(user, operator.getitem):
getitem_nodes[user.args[1]] = user
if (
is_func(query, operator.getitem)
and is_func(key, operator.getitem)
and query.args[0] == key.args[0]
and is_func(query.args[0], torch.ops.aten.split_with_sizes.default)
and all(
is_func(user, torch.ops.aten.slice_scatter.default)
for getitem_node in getitem_nodes.values()
for user in getitem_node.users
)
):
# Pattern where query and key are slices of a qkv tensor.
# While functionalized, results at [1] and [2] are scattered
# back into qkv, then split again to get query and key.
# If the inplace tensor has no other users, we can replace
# the slice_scatter+split nodes with the original results.
for user in getitem_nodes[1].users:
slice_scatter_1_node = user
if not is_func(
slice_scatter_1_node, torch.ops.aten.slice_scatter.default
):
continue
for user in getitem_nodes[2].users:
slice_scatter_2_node = user
if not is_func(
slice_scatter_2_node, torch.ops.aten.slice_scatter.default
):
continue
for user in slice_scatter_2_node.users:
split_node = user
if not is_func(split_node, torch.ops.aten.split_with_sizes.default):
continue
split_getitem_users = {}
for user in split_node.users:
if is_func(user, operator.getitem):
split_getitem_users[user.args[1]] = user
# Replace query node
split_getitem_users[0].replace_all_uses_with(getitem_nodes[1])
graph.erase_node(split_getitem_users[0])
# Replace key node
split_getitem_users[1].replace_all_uses_with(getitem_nodes[2])
graph.erase_node(split_getitem_users[1])
# Redirect value node to original qkv tensor
split_getitem_users[2].replace_input_with(split_node, query.args[0])
# Erase unused nodes
graph.erase_node(split_node)
graph.erase_node(slice_scatter_2_node)
graph.erase_node(slice_scatter_1_node)
count += 1
logger.debug("Eliminated %d slice_scatter+split nodes", count)

View File

@@ -0,0 +1,70 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Coalesce duplicate ``split_with_sizes`` nodes that operate on the same
input tensor with the same split sizes.
On certain hardware/dtype combinations (e.g. B200 + FP8) the Inductor
graph may contain multiple ``split_with_sizes`` calls on the same tensor
that CSE fails to merge. This pass detects and replaces the duplicates
so that downstream pattern-matching passes (e.g. QK-Norm+RoPE fusion)
see a single split node with all users attached.
See also:
- vLLM #33295 (original issue)
- PyTorch #174472 (upstream CSE gap)
"""
import operator
import torch
from torch import fx
from vllm.logger import init_logger
from ..fx_utils import is_func
from ..vllm_inductor_pass import VllmInductorPass
logger = init_logger(__name__)
class SplitCoalescingPass(VllmInductorPass):
"""Replace duplicate ``split_with_sizes`` nodes with a single canonical
node when they share the same input tensor and split sizes."""
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph) -> None:
count = 0
# Map from input tensor node -> list of split nodes seen so far.
split_nodes: dict[fx.Node, list[fx.Node]] = {}
for node in graph.nodes:
if not is_func(node, torch.ops.aten.split_with_sizes.default):
continue
if not all(is_func(user, operator.getitem) for user in node.users):
continue
arg_node, split_sizes = node.args[:2]
if arg_node not in split_nodes:
split_nodes[arg_node] = [node]
continue
# Find existing node with same split_sizes
canonical = next(
(
n
for n in split_nodes[arg_node]
if list(n.args[1]) == list(split_sizes)
),
None,
)
if canonical is not None:
node.replace_all_uses_with(canonical)
graph.erase_node(node)
count += 1
else:
split_nodes[arg_node].append(node)
logger.debug("Coalesced %d duplicate split_with_sizes nodes", count)

View File

@@ -0,0 +1,180 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
import operator
import time
from collections.abc import Callable
from dataclasses import dataclass
from typing import ClassVar
import regex as re
import torch
from torch._dynamo.utils import lazy_format_graph_code
from torch._inductor.pattern_matcher import PatternMatcherPass, PatternPrettyPrinter
from vllm.config import VllmConfig
from vllm.logger import init_logger
from .inductor_pass import InductorPass
logger = init_logger(__name__)
@dataclass
class InductorCompilationConfig:
splitting_ops: list[str] | None = None
use_inductor_graph_partition: bool = False
class VllmInductorPass(InductorPass):
"""
An inductor pass with access to vLLM PassConfig.
It provides timing, logging, and dumping utilities.
"""
dump_prefix: ClassVar[int | None] = None
"""Keep track of pass index for debug dump ordering."""
def __init__(self, config: VllmConfig):
# Get only the necessary CompilationConfig for the inductor pass, since
# full `CompilationConfig` contains pointer to model which is unsafe.
self.compilation_config = InductorCompilationConfig(
splitting_ops=config.compilation_config.splitting_ops,
use_inductor_graph_partition=config.compilation_config.use_inductor_graph_partition,
)
self.pass_config = config.compilation_config.pass_config
self.model_dtype = config.model_config.dtype if config.model_config else None
self.device: str | None = (
config.device_config.device if config.device_config else None
)
self.pass_name = self.__class__.__name__
@staticmethod
def time_and_log(
call_fn: Callable[["VllmInductorPass", torch.fx.Graph], None],
) -> Callable[["VllmInductorPass", torch.fx.Graph], None]:
@functools.wraps(call_fn)
def wrapped(self: VllmInductorPass, graph: torch.fx.Graph) -> None:
self.begin()
self.dump_graph(graph, "before")
call_fn(self, graph)
self.dump_graph(graph, "after")
self.end_and_log()
return wrapped
def dump_graph(self, graph: torch.fx.Graph, stage: str) -> None:
i = VllmInductorPass.dump_prefix
i_str = "" if i is None else f".{i}"
lazy_format_graph_code(
f"post_grad{i_str}.{self.pass_name}.{stage}", graph.owning_module
)
def begin(self) -> None:
self._start_time = time.perf_counter_ns()
def end_and_log(self) -> None:
self._end_time = time.perf_counter_ns()
duration_ms = float(self._end_time - self._start_time) / 1.0e6
logger.debug("%s completed in %.1f ms", self.pass_name, duration_ms)
class VllmPatternMatcherPass(VllmInductorPass):
"""
A VllmInductorPass that uses the Inductor pattern matcher.
Its main use is providing the dump_patterns utility that dumps the
Inductor pattern matcher patterns into a file, which greatly aids debugging.
TODO(luka) move more utilities to this pass.
"""
matched_count: int = 0
"""The number of matched patterns in the pass."""
_OP_OVERLOAD_PATTERN: ClassVar[re.Pattern] = re.compile(
r"<OpOverload\(op='([^']*)', overload='([^']*)'\)>"
)
def _replace_op_overloads(self, string: str) -> str:
"""Replace <OpOverload(..., ...)> with nicer formulations"""
return str(
self._OP_OVERLOAD_PATTERN.sub(
lambda m: f"torch.ops.{m.group(1)}.{m.group(2)}",
string,
)
)
def dump_patterns(self, config: VllmConfig, pm_pass: PatternMatcherPass) -> None:
"""
If debug dumping is enabled, dump the Inductor pattern-matcher patterns
into the debug_dump_path folder next to the dumped fx graphs.
This method does its best to print something that looks like Python code
for easier debugging and potentially navigation. If any errors appear in
the output, please add to this method.
TODO(luka): use pattern object to manually produce pattern graph
"""
debug_dump_path = config.compile_debug_dump_path()
if not debug_dump_path:
return
debug_dump_path.mkdir(parents=True, exist_ok=True)
from vllm.utils.system_utils import unique_filepath
file_path = unique_filepath(
lambda i: debug_dump_path / f"patterns.{self.pass_name}.{i}.py"
)
with file_path.open("w") as f:
print(
f"# This file was produced by VllmPatternMatcherPass."
f"dump_patterns for {self.pass_name}.\n"
f"# It does its best to produce valid-Python-looking code but"
f" please add to dump_patterns if there are any errors.\n\n"
f"from torch._higher_order_ops.auto_functionalize import "
f"auto_functionalized as auto_functionalized\n"
f"from torch._inductor.pattern_matcher import *\n"
f"vllm = torch.ops.vllm",
file=f,
)
for node, patterns in pm_pass.patterns.items():
# fix the operator.getitem repr
if node[1] == operator.getitem:
node_repr = f"({repr(node[0])}, operator.getitem)"
else:
node_repr = repr(node)
node_repr = self._replace_op_overloads(node_repr)
print(f"\n\n# Patterns for op: {node_repr}", file=f)
for i, pattern in enumerate(patterns):
# reserve auto_functionalized ahead of time
pp = PatternPrettyPrinter()
pp.namespace.create_name("auto_functionalized", None)
# Assemble pattern
out_node = pp.pretty_print(pattern.pattern)
pattern_repr = "\n".join(
[f"def pattern_{i}():"]
+ [
f"{pp.memoized_objs_names[key]} = "
f"{pp.memoized_objs_pp[key]}"
for key in pp.memoized_objs_names
]
+ [f"return {out_node}"]
).replace("\n", "\n ")
pattern_repr = self._replace_op_overloads(pattern_repr)
print(f"{pattern_repr}\n", file=f)
class PrinterInductorPass(VllmInductorPass):
def __init__(self, name: str, config: VllmConfig) -> None:
super().__init__(config)
self.name = name
def __call__(self, graph: torch.fx.Graph) -> None:
self.dump_graph(graph, self.name)