diff --git a/docs/source/user_guide/configuration/additional_config.md b/docs/source/user_guide/configuration/additional_config.md
index 5a5f12c5..b5e54ed3 100644
--- a/docs/source/user_guide/configuration/additional_config.md
+++ b/docs/source/user_guide/configuration/additional_config.md
@@ -96,6 +96,9 @@ The details of each configuration option are as follows:
 |------------------------| ---- |---------|----------------------------------------------------------------------------------------|
 | `enable` | bool | `False` | Whether to enable npugraph_ex backend. |
 | `enable_static_kernel` | bool | `False` | Whether to enable static kernel. Suitable for scenarios where shape changes are minimal and some time is available for static kernel compilation. |
+| `fuse_norm_quant` | bool | `True` | Whether to enable the fuse_norm_quant pass. |
+| `fuse_qknorm_rope` | bool | `True` | Whether to enable the fuse_qknorm_rope pass. Set it to `False` if Triton is not available in the environment. |
+| `fuse_allreduce_rms` | bool | `False` | Whether to enable the fuse_allreduce_rms pass. It defaults to `False` because it conflicts with sequence parallelism (SP). |
 
 ### Example
diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py
index 1d119dbf..83e92e40 100644
--- a/vllm_ascend/ascend_config.py
+++ b/vllm_ascend/ascend_config.py
@@ -235,7 +235,15 @@ class NpugraphExConfig:
     These configurations can directly impact the performance and behavior of models
     deployed on Ascend platforms.
     """
 
-    def __init__(self, enable: bool = False, enable_static_kernel: bool = False, **kwargs):
+    def __init__(
+        self,
+        enable: bool = False,
+        enable_static_kernel: bool = False,
+        fuse_norm_quant: bool = True,
+        fuse_qknorm_rope: bool = True,
+        fuse_allreduce_rms: bool = False,
+        **kwargs,
+    ):
         """
         Initialize the configuration.
@@ -251,10 +259,20 @@ class NpugraphExConfig:
                 binary files with the corresponding shapes based on the current batch_size,
                 which usually takes some time.
                 Default: False
+            fuse_norm_quant (bool): Whether to enable the norm and quant fusion optimization.
+                When set to True, adjacent norm and quant operations are fused.
+                Default: True
+            fuse_qknorm_rope (bool): Whether to enable the qknorm and rope fusion optimization (requires Triton).
+                Default: True
+            fuse_allreduce_rms (bool): Whether to enable the allreduce and addrmsnorm fusion optimization (conflicts with sequence parallelism).
+                Default: False
             **kwargs: Additional optional parameters for forward compatibility and configuration extension.
         """
         self.enable = enable
         self.enable_static_kernel = enable_static_kernel
+        self.fuse_norm_quant = fuse_norm_quant
+        self.fuse_qknorm_rope = fuse_qknorm_rope
+        self.fuse_allreduce_rms = fuse_allreduce_rms
 
 
 class XliteGraphConfig:
diff --git a/vllm_ascend/compilation/npu_graph_ex_pass_manager.py b/vllm_ascend/compilation/npu_graph_ex_pass_manager.py
index be810de0..15e88b41 100644
--- a/vllm_ascend/compilation/npu_graph_ex_pass_manager.py
+++ b/vllm_ascend/compilation/npu_graph_ex_pass_manager.py
@@ -48,4 +48,18 @@ class NpuGraphEXPassManager:
 
     def configure(self, config: VllmConfig):
         # By default, we enable the graph fusion and quantization fusion pass.
-        self.ascend_compilation_config: dict = config.additional_config.get("ascend_compilation_config", {})
+        self.npugraph_ex_config: dict = config.additional_config.get("npugraph_ex_config", {})
+        if self.npugraph_ex_config.get("fuse_norm_quant", True):
+            from .npugraph_ex_passes.graphex_norm_quant_fusion_pass import GraphEXAddRMSNormFusionPass
+
+            self.passes.append(GraphEXAddRMSNormFusionPass(config))
+
+        if self.npugraph_ex_config.get("fuse_qknorm_rope", True):
+            from .npugraph_ex_passes.graphex_qknorm_rope_fusion_pass import GraphEXQKNormRopeFusionPass
+
+            self.passes.append(GraphEXQKNormRopeFusionPass(config))
+
+        if self.npugraph_ex_config.get("fuse_allreduce_rms", False):
+            from .npugraph_ex_passes.graphex_allreduce_rmsnorm_fusion_pass import GraphEXMatmulAllReduceAddRMSNormPass
+
+            self.passes.append(GraphEXMatmulAllReduceAddRMSNormPass(config))
diff --git a/vllm_ascend/compilation/npugraph_ex_passes/graphex_allreduce_rmsnorm_fusion_pass.py b/vllm_ascend/compilation/npugraph_ex_passes/graphex_allreduce_rmsnorm_fusion_pass.py
new file mode 100644
index 00000000..94a08389
--- /dev/null
+++ b/vllm_ascend/compilation/npugraph_ex_passes/graphex_allreduce_rmsnorm_fusion_pass.py
@@ -0,0 +1,150 @@
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# This file is a part of the vllm-ascend project.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import torch
+import torchair
+from vllm.config import VllmConfig
+from vllm.config.compilation import Range
+from vllm.distributed import get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce
+from vllm.distributed.parallel_state import get_tp_group
+
+from vllm_ascend.compilation.npugraph_ex_passes.utils.npugraph_ex_utils_check import extra_stream_scope_check
+
+# The computation-communication tiling block size is 512.
+ALLREDUCE_NORM_FUSE_THREHOLD = 512
+
+
+class GraphEXMiddleLayerMatmulAllReduceAddRMSNormPattern:
+    """
+    Recognizes the Matmul + AllReduce + AddRMSNorm computation pattern.
+    In the fused operator, the AllReduce is replaced by a two-stage ReduceScatter + AllGather communication.
+    """
+
+    def __init__(self, vllm_config, eps=1e-6):
+        self.vllm_config = vllm_config
+        self.eps = eps
+        device_group = get_tp_group().device_group
+        backend = device_group._get_backend(torch.device("npu"))
+        self.local_rank = torch.distributed.get_rank(group=device_group)
+        self.tp_group_name = backend.get_hccl_comm_name(self.local_rank)
+        self.tp_size = get_tensor_model_parallel_world_size()
+
+    def get_inputs(self):
+        batch_size, seq_len = 2, 4
+        hidden_size = 4096
+        x = torch.randn(batch_size, seq_len, hidden_size, device="npu")
+        weight = torch.randn(hidden_size, hidden_size, device="npu")
+        residual = torch.randn(batch_size, seq_len, hidden_size, device="npu")
+        rms_norm_weight = torch.randn(hidden_size, device="npu")
+        return [x, weight, residual, rms_norm_weight]
+
+    def register(self):
+        def pattern(x, weight, residual, rms_norm_weight):
+            mm = torch.ops.vllm.unquantized_gemm(x, weight, None)
+            all_reduce_ = tensor_model_parallel_all_reduce(mm)
+            output = torch.ops._C_ascend.npu_add_rms_norm_bias(all_reduce_, residual, rms_norm_weight, None)
+            out0 = output[0]
+            out1 = output[2]
+
+            return out0, out1
+
+        def replacement(x, weight, residual, rms_norm_weight):
+            out0, out1 = torch.ops._C_ascend.matmul_allreduce_add_rmsnorm(
+                x,
+                weight,
+                residual,
+                rms_norm_weight,
+                self.tp_group_name,
+                self.tp_size,
+                self.local_rank,
+                self.eps,
+                True,
+                False,
+            )
+            return out0, out1
+
+        torchair.register_replacement(
+            search_fn=pattern,
+            replace_fn=replacement,
+            example_inputs=self.get_inputs(),
+            extra_check=extra_stream_scope_check,
+        )
+
+
+class GraphEXLastLayerMatmulAllReduceAddRMSNormPattern:
+    def __init__(self, vllm_config, eps=1e-6):
+        self.vllm_config = vllm_config
+        self.eps = eps
+        device_group = get_tp_group().device_group
+        backend = device_group._get_backend(torch.device("npu"))
+        self.local_rank = torch.distributed.get_rank(group=device_group)
+        self.tp_group_name = backend.get_hccl_comm_name(self.local_rank)
+        self.tp_size = get_tensor_model_parallel_world_size()
+
+    def get_inputs(self):
+        batch_size, seq_len = 2, 4
+        hidden_size = 4096
+        x = torch.randn(batch_size, seq_len, hidden_size, device="npu")
+        weight = torch.randn(hidden_size, hidden_size, device="npu")
+        residual = torch.randn(batch_size, seq_len, hidden_size, device="npu")
+        rms_norm_weight = torch.randn(hidden_size, device="npu")
+        return [x, weight, residual, rms_norm_weight]
+
+    def register(self):
+        def pattern(x, weight, residual, rms_norm_weight):
+            mm = torch.ops.vllm.unquantized_gemm(x, weight, None)
+            all_reduce_ = tensor_model_parallel_all_reduce(mm)
+            output = torch.ops._C_ascend.npu_add_rms_norm_bias(all_reduce_, residual, rms_norm_weight, None)
+
+            return output[0]
+
+        def replacement(x, weight, residual, rms_norm_weight):
+            out0, _ = torch.ops._C_ascend.matmul_allreduce_add_rmsnorm(
+                x,
+                weight,
+                residual,
+                rms_norm_weight,
+                self.tp_group_name,
+                self.tp_size,
+                self.local_rank,
+                self.eps,
+                True,
+                False,
+            )
+            return out0
+
+        torchair.register_replacement(
+            search_fn=pattern,
+            replace_fn=replacement,
+            example_inputs=self.get_inputs(),
+            extra_check=extra_stream_scope_check,
+        )
+
+
+class GraphEXMatmulAllReduceAddRMSNormPass:
+    def __init__(self, vllm_config: VllmConfig):
+        GraphEXMiddleLayerMatmulAllReduceAddRMSNormPattern(vllm_config).register()
+        GraphEXLastLayerMatmulAllReduceAddRMSNormPattern(vllm_config).register()
+
+    def __call__(self, graph: torch.fx.Graph):
+        pass
+
+    def is_applicable_for_range(self, compile_range: Range) -> bool:
+        """
+        Check whether the pass is applicable for the given compile range.
+        """
+        applicable = compile_range.start > ALLREDUCE_NORM_FUSE_THREHOLD
+        return applicable
diff --git a/vllm_ascend/compilation/passes/allreduce_rmsnorm_fusion_pass.py b/vllm_ascend/compilation/passes/allreduce_rmsnorm_fusion_pass.py
index ee53547b..006d329b 100644
--- a/vllm_ascend/compilation/passes/allreduce_rmsnorm_fusion_pass.py
+++ b/vllm_ascend/compilation/passes/allreduce_rmsnorm_fusion_pass.py
@@ -56,7 +56,7 @@ class MiddleLayerMatmulAllReduceAddRMSNormPattern:
         def pattern(x, weight, residual, rms_norm_weight):
             mm = torch.ops.vllm.unquantized_gemm(x, weight, None)
             all_reduce_ = tensor_model_parallel_all_reduce(mm)
-            output = torch.ops.npu.npu_add_rms_norm(all_reduce_, residual, rms_norm_weight)
+            output = torch.ops._C_ascend.npu_add_rms_norm_bias(all_reduce_, residual, rms_norm_weight, None)
             out0 = output[0]
             out1 = output[2]
 
@@ -103,7 +103,7 @@ class LastLayerMatmulAllReduceAddRMSNormPattern:
         def pattern(x, weight, residual, rms_norm_weight):
             mm = torch.ops.vllm.unquantized_gemm(x, weight, None)
             all_reduce_ = tensor_model_parallel_all_reduce(mm)
-            output = torch.ops.npu.npu_add_rms_norm(all_reduce_, residual, rms_norm_weight)
+            output = torch.ops._C_ascend.npu_add_rms_norm_bias(all_reduce_, residual, rms_norm_weight, None)
 
             return output[0]
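
For reference, below is a minimal sketch of how the new `npugraph_ex_config` options introduced by this change could be toggled from user code. It is not part of the patch: the model path is a placeholder, and passing `additional_config` directly to `LLM(...)` is assumed to follow the existing examples in `additional_config.md`.

```python
# Minimal sketch (assumption, not part of this patch): enabling the npugraph_ex
# backend and selectively toggling the new fusion passes via additional_config.
from vllm import LLM

llm = LLM(
    model="/path/to/your/model",  # placeholder model path
    additional_config={
        "npugraph_ex_config": {
            "enable": True,                # turn on the npugraph_ex backend
            "enable_static_kernel": False,
            "fuse_norm_quant": True,
            "fuse_qknorm_rope": False,     # set to False when Triton is not available
            "fuse_allreduce_rms": False,   # keep False when sequence parallelism (SP) is used
        },
    },
)
```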