[Fusion][Graph] Add Matmul AllReduce RMSNorm fusion pass (#5034)

This PR adds the `MatmulAllreduceRmsnorm` operator and introduces a graph
fusion pass for `matmul_allreduce_rmsnorm` operations. The implementation
includes a new configuration flag and a pattern-matching pass built on
`torch._inductor.pattern_matcher`.
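
For reference, a hedged sketch of how the new flag would be switched on from user code (`additional_config` is vLLM's existing plumbing for platform plugins; the model name and parallel size here are placeholders):

```python
from vllm import LLM

# Hypothetical usage: enable the fusion added in this PR. The pattern contains
# a tensor-parallel all-reduce, so it only fires with tensor_parallel_size > 1.
llm = LLM(
    model="some/model",  # placeholder
    tensor_parallel_size=4,
    additional_config={
        "ascend_compilation_config": {
            "fuse_allreduce_rms": True,
        },
    },
)
```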

Co-authored-by: Trunrain <270250579@qq.com>

- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: wxsIcey <1790571317@qq.com>
Signed-off-by: tongrunze <t00574058@china.huawei.com>
Icey
2026-01-19 09:28:07 +08:00
committed by GitHub
parent 9cad1a8349
commit c929bd1e8d
8 changed files with 251 additions and 1 deletion

View File

@@ -169,7 +169,9 @@ class AscendCompilationConfig:
deployed on Ascend platforms.
"""
def __init__(self, fuse_norm_quant: bool = True, fuse_qknorm_rope: bool = False, **kwargs):
def __init__(
self, fuse_norm_quant: bool = True, fuse_qknorm_rope: bool = False, fuse_allreduce_rms: bool = False, **kwargs
):
"""
Initialize the configuration.
@@ -179,10 +181,13 @@ class AscendCompilationConfig:
Default: True
fuse_qknorm_rope (bool): Whether to enable qknorm and rope fusion optimization.
Default: False
fuse_allreduce_rms (bool): Whether to enable allreduce and addrmsnorm fusion optimization.
Default: False
**kwargs: Additional optional parameters for forward compatibility and configuration extension.
"""
self.fuse_norm_quant = fuse_norm_quant
self.fuse_qknorm_rope = HAS_TRITON or fuse_qknorm_rope
self.fuse_allreduce_rms = fuse_allreduce_rms
class XliteGraphConfig:
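
A minimal sketch of the new flag in isolation (constructing the config directly, outside the engine):

```python
cfg = AscendCompilationConfig()  # fuse_allreduce_rms defaults to False
assert cfg.fuse_allreduce_rms is False

cfg = AscendCompilationConfig(fuse_allreduce_rms=True)
assert cfg.fuse_allreduce_rms is True  # stored verbatim, unlike fuse_qknorm_rope
```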

View File

@@ -58,3 +58,8 @@ class GraphFusionPassManager:
from .passes.qknorm_rope_fusion_pass import QKNormRopeFusionPass
self.passes.append(QKNormRopeFusionPass(config))
if self.ascend_compilation_config.get("fuse_allreduce_rms", True):
from .passes.allreduce_rmsnorm_fusion_pass import MatmulAllReduceAddRMSNormPass
self.passes.append(MatmulAllReduceAddRMSNormPass(config))
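
For context, a minimal standalone example of the `torch._inductor.pattern_matcher` API the new pass is built on (a toy pattern, unrelated to this PR's fusion): replace `add(mul(a, b), c)` with a single `addcmul`.

```python
import torch
import torch._inductor.pattern_matcher as pm
from torch._inductor.pattern_matcher import PatternMatcherPass

toy_pass = PatternMatcherPass(pass_name="toy_fusion")

def pattern(a, b, c):
    return torch.add(torch.mul(a, b), c)

def replacement(a, b, c):
    # addcmul(c, a, b) == c + a * b in one kernel
    return torch.addcmul(c, a, b)

example_inputs = [torch.randn(4), torch.randn(4), torch.randn(4)]
pm.register_replacement(pattern, replacement, example_inputs, pm.fwd_only, toy_pass)
# toy_pass.apply(graph) rewrites every match in an FX graph and returns the
# number of replacements, which is exactly how MatmulAllReduceAddRMSNormPass uses it.
```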

View File

@@ -0,0 +1,153 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import torch
import torch._inductor.pattern_matcher as pm
from torch._inductor.pattern_matcher import PatternMatcherPass, PatternPrettyPrinter
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
from vllm.config import VllmConfig
from vllm.config.compilation import Range
from vllm.distributed import get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import get_tp_group
from vllm.logger import logger
# The computation-communication tiling block size is 512.
ALLREDUCE_NORM_FUSE_THRESHOLD = 512
class MiddleLayerMatmulAllReduceAddRMSNormPattern:
"""
recognizing the Matmul+AllReduce+AddRMSNorm computation pattern
AllReduce is optimized in the fusion operator to a two-stage communication of ReduceScatter+AllGather
"""
def __init__(self, vllm_config, eps=1e-6):
self.vllm_config = vllm_config
self.eps = eps
device_group = get_tp_group().device_group
backend = device_group._get_backend(torch.device("npu"))
self.local_rank = torch.distributed.get_rank(group=device_group)
self.tp_group_name = backend.get_hccl_comm_name(self.local_rank)
self.tp_size = get_tensor_model_parallel_world_size()
def get_inputs(self):
batch_size, seq_len = 2, 4
hidden_size = 4096
x = torch.randn(batch_size, seq_len, hidden_size, device="npu")
weight = torch.randn(hidden_size, hidden_size, device="npu")
residual = torch.randn(batch_size, seq_len, hidden_size, device="npu")
rms_norm_weight = torch.randn(hidden_size, device="npu")
return [x, weight, residual, rms_norm_weight]
def register(self, pm_pass: PatternMatcherPass):
def pattern(x, weight, residual, rms_norm_weight):
mm = torch.ops.vllm.unquantized_gemm(x, weight, None)
all_reduce_ = tensor_model_parallel_all_reduce(mm)
output = torch.ops.npu.npu_add_rms_norm(all_reduce_, residual, rms_norm_weight)
out0 = output[0]
out1 = output[2]
return out0, out1
def replacement(x, weight, residual, rms_norm_weight):
out0, out1 = torch.ops._C_ascend.matmul_allreduce_add_rmsnorm(
x,
weight,
residual,
rms_norm_weight,
self.tp_group_name,
self.tp_size,
self.local_rank,
self.eps,
True,
False,
)
return out0, out1
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
class LastLayerMatmulAllReduceAddRMSNormPattern:
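"""
Same pattern as the middle-layer variant, except that the last layer only
consumes the normalized output (output[0]) and discards the updated residual.
"""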
def __init__(self, vllm_config, eps=1e-6):
self.vllm_config = vllm_config
self.eps = eps
device_group = get_tp_group().device_group
backend = device_group._get_backend(torch.device("npu"))
self.local_rank = torch.distributed.get_rank(group=device_group)
self.tp_group_name = backend.get_hccl_comm_name(self.local_rank)
self.tp_size = get_tensor_model_parallel_world_size()
def get_inputs(self):
batch_size, seq_len = 2, 4
hidden_size = 4096
x = torch.randn(batch_size, seq_len, hidden_size, device="npu")
weight = torch.randn(hidden_size, hidden_size, device="npu")
residual = torch.randn(batch_size, seq_len, hidden_size, device="npu")
rms_norm_weight = torch.randn(hidden_size, device="npu")
return [x, weight, residual, rms_norm_weight]
def register(self, pm_pass: PatternMatcherPass):
def pattern(x, weight, residual, rms_norm_weight):
mm = torch.ops.vllm.unquantized_gemm(x, weight, None)
all_reduce_ = tensor_model_parallel_all_reduce(mm)
output = torch.ops.npu.npu_add_rms_norm(all_reduce_, residual, rms_norm_weight)
return output[0]
def replacement(x, weight, residual, rms_norm_weight):
out0, _ = torch.ops._C_ascend.matmul_allreduce_add_rmsnorm(
x,
weight,
residual,
rms_norm_weight,
self.tp_group_name,
self.tp_size,
self.local_rank,
self.eps,
True,
False,
)
return out0
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
class MatmulAllReduceAddRMSNormPass(VllmInductorPass):
def __init__(self, vllm_config: VllmConfig):
super().__init__(vllm_config)
self.pattern_match_passes: PatternMatcherPass = PatternMatcherPass(pass_name="allreduce_rmsnorm_fusion_pass")
MiddleLayerMatmulAllReduceAddRMSNormPattern(vllm_config).register(self.pattern_match_passes)
LastLayerMatmulAllReduceAddRMSNormPattern(vllm_config).register(self.pattern_match_passes)
def __call__(self, graph: torch.fx.Graph):
self.begin()
self.matched_count = self.pattern_match_passes.apply(graph)
pattern_idx = 0
for pattern_entry in self.pattern_match_passes.patterns.values():
for p in pattern_entry:
p_str = PatternPrettyPrinter.run(p.pattern)
logger.debug("Pattern %d: %s", pattern_idx, p_str)
pattern_idx += 1
logger.debug("Replaced %s patterns", self.matched_count)
self.end_and_log()
def is_applicable_for_range(self, compile_range: Range) -> bool:
"""
Check if the pass is applicable for the current configuration.
"""
applicable = compile_range.start > ALLREDUCE_NORM_FUSE_THREHOLD
return applicable
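
The pattern docstring notes that the fused operator decomposes the AllReduce into ReduceScatter + AllGather. A small illustrative sketch of that algebraic equivalence (toy `torch.distributed` code, not the fused kernel; assumes an initialized process group and that the row count divides evenly by the world size):

```python
import torch
import torch.distributed as dist

def two_stage_all_reduce(x: torch.Tensor) -> torch.Tensor:
    world_size = dist.get_world_size()
    # Stage 1: each rank ends up with the fully reduced values for one shard.
    shard = torch.empty(
        x.shape[0] // world_size, *x.shape[1:], dtype=x.dtype, device=x.device
    )
    dist.reduce_scatter_tensor(shard, x)
    # Stage 2: ranks exchange their reduced shards.
    out = torch.empty_like(x)
    dist.all_gather_into_tensor(out, shard)
    return out  # elementwise equal to the result of dist.all_reduce(x)
```

Splitting the collective in two is what lets the fused kernel overlap computation with communication (the 512-element tiling noted above), instead of waiting on a monolithic all-reduce.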

View File

@@ -127,12 +127,15 @@
# 1. `vllm.distributed.parallel_state.GroupCoordinator`
# Why:
# vllm doesn't support all_to_all for GroupCoordinator.
# all_reduce in vLLM is not a custom op, which makes the MatmulAllReduceAddRMSNorm fusion fail.
# How:
# Add all_to_all implementation for GroupCoordinator.
# Make all_reduce a custom op.
# Related PR (if no, explain why):
# No, we should use the vLLM all2all manager to support all_to_all on NPU.
# Future Plan:
# Remove this patch when the refactor of all2all manager is done.
# Remove this patch when vLLM supports all_reduce as a custom op.
#
# ** 3. File: worker/patch_minicpm.py **
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -276,3 +279,12 @@
# Future Plan:
# Remove this patch when cann fix the gather bug.
#
# ** 13. File: worker/patch_unquantized_gemm.py**
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# 1. `vllm.model_executor.layers.utils.default_unquantized_gemm`
# Why:
# unquantized_gemm in vLLM is not a custom op, which makes the MatmulAllReduceAddRMSNorm fusion fail.
# How:
# Make unquantized_gemm a custom op.
# Future Plan:
# Remove this patch when vLLM supports the operator as a custom op.

View File

@@ -22,6 +22,7 @@ if HAS_TRITON:
# isort: off
import vllm_ascend.patch.platform.patch_sched_yield # noqa
import vllm_ascend.patch.worker.patch_unquantized_gemm # noqa
import vllm_ascend.patch.worker.patch_bert # noqa
import vllm_ascend.patch.worker.patch_distributed # noqa
import vllm_ascend.patch.worker.patch_deepseek # noqa

View File

@@ -112,5 +112,10 @@ class GroupCoordinatorPatch(GroupCoordinator):
gather_dim, scatter_sizes,
gather_sizes)
def all_reduce(self, input_):
if self.world_size == 1:
return input_
return torch.ops.vllm.all_reduce(input_, group_name=self.unique_name)
vllm.distributed.parallel_state.GroupCoordinator = GroupCoordinatorPatch

View File

@@ -0,0 +1,57 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import torch
import vllm.model_executor.layers.utils
from vllm.utils.torch_utils import direct_register_custom_op
def unquantized_gemm(
x: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
return torch.nn.functional.linear(x, weight, bias)
def unquantized_gemm_fake(
x: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
output_shape = (x.shape[0], weight.shape[0])
return torch.empty(output_shape, dtype=x.dtype, device=x.device)
direct_register_custom_op(op_name="unquantized_gemm",
op_func=unquantized_gemm,
fake_impl=unquantized_gemm_fake,
mutates_args=[],
dispatch_key="PrivateUse1")
def default_unquantized_gemm(
layer: torch.nn.Module,
x: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
if x.device.type == "npu":
return torch.ops.vllm.unquantized_gemm(x, weight, bias)
else:
return torch.nn.functional.linear(x, weight, bias)
vllm.model_executor.layers.utils.default_unquantized_gemm = default_unquantized_gemm
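
Once registered, the op is reachable through the `torch.ops.vllm` namespace and appears as a single node in traced graphs, which is what the fusion pattern matches against. A quick smoke test (assumes an Ascend environment with `torch_npu`, and that this patch module has been imported):

```python
import torch
import torch_npu  # noqa: F401  (provides the "npu" device)

x = torch.randn(8, 4096, device="npu")
w = torch.randn(4096, 4096, device="npu")
y = torch.ops.vllm.unquantized_gemm(x, w, None)
assert torch.allclose(y, torch.nn.functional.linear(x, w))
```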

View File

@@ -192,6 +192,18 @@ class NPUPlatform(Platform):
else ascend_compilation_config
)
if vllm_config.additional_config.get("ascend_compilation_config", {}).get("fuse_allreduce_rms", True):
from vllm_ascend.compilation.passes.allreduce_rmsnorm_fusion_pass import ALLREDUCE_NORM_FUSE_THRESHOLD
new_compile_ranges_split_points = vllm_config.compilation_config.compile_ranges_split_points
new_compile_ranges_split_points.append(ALLREDUCE_NORM_FUSE_THRESHOLD)
new_compile_ranges_split_points = sorted(new_compile_ranges_split_points)
vllm_config.compilation_config.compile_ranges_split_points = new_compile_ranges_split_points
logger.debug(
"set compile_ranges_split_points to %s for matmul and allreduce fusion",
new_compile_ranges_split_points,
)
elif model_config and hasattr(model_config.hf_text_config, "index_topk"):
vllm_config.cache_config.cache_dtype = str(model_config.dtype).replace("torch.", "")
if model_config is None:
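
To see why the split point is appended: `is_applicable_for_range` above only lets the pass run when a compile range starts beyond the 512 tiling threshold, and inserting 512 into `compile_ranges_split_points` guarantees a range boundary lands exactly there. A dependency-free sketch of the resulting behavior (range semantics assumed from the code above):

```python
ALLREDUCE_NORM_FUSE_THRESHOLD = 512

def is_applicable(range_start: int) -> bool:
    # Mirrors MatmulAllReduceAddRMSNormPass.is_applicable_for_range
    return range_start > ALLREDUCE_NORM_FUSE_THRESHOLD

split_points = sorted([2048, ALLREDUCE_NORM_FUSE_THRESHOLD])  # -> [512, 2048]
print(is_applicable(1))    # False: small graphs keep the unfused matmul + all-reduce
print(is_applicable(513))  # True: ranges above the split use the fused kernel
```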