xc-llm-ascend/vllm_ascend/compilation/npu_graph_ex_pass_manager.py
CodeCat 54e8389f8e [Graph][Fusion] Add MatmulAllReduceAddRMSNorm graph fusion for npugraph_ex. (#6006)
### What this PR does / why we need it?
This PR builds on
https://github.com/vllm-project/vllm-ascend/pull/5011 and further
extends the npu_graph_ex_passes module. Building on that prior work, we
added graph optimization support for the add_rms_quant fused operator
when a bias term is present, ensuring that the fusion pattern is
correctly registered and matched in the computation graph.

This PR adds the MatmulAllReduceAddRMSNorm operator fusion, along with
corresponding ST test cases for regression monitoring.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?

- vLLM version: v0.13.0
- vLLM main: 2c24bc6996

---------

Signed-off-by: cjian <2318164299@qq.com>
2026-01-27 16:41:48 +08:00


#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from torch import fx as fx
from vllm.compilation.inductor_pass import get_pass_context
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
from vllm.config import VllmConfig


class NpuGraphEXPassManager:
    """
    A pass manager for npugraph_ex fusion passes.
    It handles the configuration and execution of passes.
    The counterpart in vllm is PostGradPassManager. Since torch_npu
    does not support triton for now, we define our own pass manager.
    """

    def __init__(self):
        self.passes: list[VllmInductorPass] = []

    def __call__(self, graph: fx.Graph) -> fx.Graph:
        compile_range = get_pass_context().compile_range

        # Run only the passes that apply to the current compile range.
        for pass_ in self.passes:
            if pass_.is_applicable_for_range(compile_range):
                pass_(graph)
        # Regenerate the Python code once, after all passes have run.
        graph.recompile()
        return graph

    def add(self, pass_: VllmInductorPass):
        assert isinstance(pass_, VllmInductorPass)
        self.passes.append(pass_)

    def configure(self, config: VllmConfig):
        # Every fusion pass is enabled by default; each one can be turned off
        # through its flag in additional_config["npugraph_ex_config"].
        self.npugraph_ex_config: dict = config.additional_config.get("npugraph_ex_config", {})
        if self.npugraph_ex_config.get("fuse_norm_quant", True):
            from .npugraph_ex_passes.graphex_norm_quant_fusion_pass import GraphEXAddRMSNormFusionPass

            self.passes.append(GraphEXAddRMSNormFusionPass(config))

        if self.npugraph_ex_config.get("fuse_qknorm_rope", True):
            from .npugraph_ex_passes.graphex_qknorm_rope_fusion_pass import GraphEXQKNormRopeFusionPass

            self.passes.append(GraphEXQKNormRopeFusionPass(config))

        if self.npugraph_ex_config.get("fuse_allreduce_rms", True):
            from .npugraph_ex_passes.graphex_allreduce_rmsnorm_fusion_pass import GraphEXMatmulAllReduceAddRMSNormPass

            self.passes.append(GraphEXMatmulAllReduceAddRMSNormPass(config))
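
The configure-then-run flow above can be illustrated with a minimal standalone sketch. It is not the vllm-ascend implementation: `ToyPass`, `ToyPassManager`, `fuse`, and the op names are hypothetical stand-ins. A "graph" here is just a list of op names and a compile range is a plain `(start, end)` tuple, so the example runs without torch or vllm. Only the three `npugraph_ex_config` flag names and their default of `True` are taken from the file.

```python
from dataclasses import dataclass, field
from typing import Callable

# Hypothetical stand-in: a "graph" is a flat list of op names.
Graph = list


def fuse(pattern: list, fused: str) -> Callable[[Graph], None]:
    """Return a rewrite that replaces each occurrence of `pattern` with `fused`."""

    def rewrite(graph: Graph) -> None:
        i = 0
        while i <= len(graph) - len(pattern):
            if graph[i:i + len(pattern)] == pattern:
                graph[i:i + len(pattern)] = [fused]
            i += 1

    return rewrite


@dataclass
class ToyPass:
    """Hypothetical pass: a named in-place rewrite plus an applicability check."""

    name: str
    rewrite: Callable[[Graph], None]

    def is_applicable_for_range(self, compile_range: tuple) -> bool:
        # Assumption for this sketch: every pass applies to any non-empty range.
        start, end = compile_range
        return end >= start

    def __call__(self, graph: Graph) -> None:
        self.rewrite(graph)


@dataclass
class ToyPassManager:
    passes: list = field(default_factory=list)

    def configure(self, additional_config: dict) -> None:
        # Same flag names as the real file; a missing flag counts as enabled.
        cfg = additional_config.get("npugraph_ex_config", {})
        if cfg.get("fuse_norm_quant", True):
            self.passes.append(
                ToyPass("norm_quant", fuse(["add_rms_norm", "quant"], "add_rms_norm_quant")))
        if cfg.get("fuse_allreduce_rms", True):
            self.passes.append(
                ToyPass("allreduce_rms",
                        fuse(["matmul", "all_reduce", "add_rms_norm"],
                             "matmul_allreduce_add_rms_norm")))

    def __call__(self, graph: Graph, compile_range: tuple) -> Graph:
        for pass_ in self.passes:
            if pass_.is_applicable_for_range(compile_range):
                pass_(graph)
        return graph


def run(additional_config: dict) -> Graph:
    """Build a manager from `additional_config` and run it on a sample graph."""
    manager = ToyPassManager()
    manager.configure(additional_config)
    ops = ["matmul", "all_reduce", "add_rms_norm", "linear", "add_rms_norm", "quant"]
    return manager(ops, (1, 8))


all_enabled = run({})
allreduce_disabled = run({"npugraph_ex_config": {"fuse_allreduce_rms": False}})
```

With an empty config both toy passes run. Setting `fuse_allreduce_rms` to `False` leaves the `matmul`, `all_reduce`, `add_rms_norm` sequence unfused while the norm-quant rewrite still applies. Passes run in registration order, so in a real graph an earlier fusion can consume nodes that a later pattern would otherwise have matched.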