From f40256b697b9ab533e7e3a9a86a1d2ac8dd7c9b3 Mon Sep 17 00:00:00 2001
From: Shaoxu Cheng <2906339855@qq.com>
Date: Fri, 13 Feb 2026 15:40:49 +0800
Subject: [PATCH] [Feat.][310P] addrmsnorm for 300I DUO (#6704)

### What this PR does / why we need it?
This PR integrates the `npu_add_rms_norm` fused kernel for RMSNorm operations with residual connections on 310P devices, replacing the previous two-step process (a manual residual addition followed by `npu_rms_norm`) with a single fused operation. This improves the performance of models that use RMSNorm with residual connections on the 310P architecture. The PR also removes the 310P-specific `AscendMMEncoderAttention310` implementation and its custom-op registration. A plain-PyTorch reference for the fused semantics is sketched after the diff.

Fixes #

### Does this PR introduce _any_ user-facing change?
No, this PR is an internal optimization and does not change any user-facing APIs or behaviors.

### How was this patch tested?
This patch was tested with an updated unit test (`test_RMSNorm_forward_310p`) that mocks the `npu_add_rms_norm` operation to verify that the fused kernel is wired up correctly; see the mock sketch after the diff.

---------

Signed-off-by: Tflowers-0129 <2906339855@qq.com>
---
 tests/ut/ops/test_layernorm.py                | 19 +++---
 vllm_ascend/_310p/ops/layernorm.py            | 10 +--
 vllm_ascend/_310p/ops/mm_encoder_attention.py | 61 -------------------
 vllm_ascend/utils.py                          |  2 -
 4 files changed, 12 insertions(+), 80 deletions(-)
 delete mode 100644 vllm_ascend/_310p/ops/mm_encoder_attention.py

diff --git a/tests/ut/ops/test_layernorm.py b/tests/ut/ops/test_layernorm.py
index 2a290382..ea112504 100644
--- a/tests/ut/ops/test_layernorm.py
+++ b/tests/ut/ops/test_layernorm.py
@@ -5,7 +5,7 @@ import torch
 from vllm.config import set_current_vllm_config
 from vllm.model_executor.layers.layernorm import RMSNorm
 
-from vllm_ascend.utils import AscendDeviceType, enable_custom_op
+from vllm_ascend.utils import enable_custom_op
 from vllm_ascend.utils import is_310p as is_310p_hw
 
 enable_custom_op()
@@ -39,8 +39,8 @@ def default_vllm_config():
     with set_current_vllm_config(mock_config):
         yield mock_config
 
-@pytest.mark.skip(
-    "Skip as register_kernels has NPU SocName checking in CANN 8.5.0.")
+
+@pytest.mark.skip("Skip as register_kernels has NPU SocName checking in CANN 8.5.0.")
 @pytest.mark.skipif(is_310p_hw(), reason="non_310P device unittest case.")
 @pytest.mark.parametrize("residual", [None, torch.randn(4, 8, dtype=torch.float32)])
 @patch("torch_npu.npu_rms_norm", side_effect=mock_rms_norm)
@@ -68,19 +68,18 @@ def test_RMSNorm_forward(
 @pytest.mark.skipif(not is_310p_hw(), reason="310P device unittest case.")
 @pytest.mark.parametrize("residual", [None, torch.randn(4, 8, dtype=torch.float16)])
 @patch("torch_npu.npu_rms_norm", side_effect=mock_rms_norm)
-def test_RMSNorm_forward_310p(
-    mock_rmsnorm, residual, dummy_tensor, default_vllm_config
-):
+@patch("torch_npu.npu_add_rms_norm", side_effect=mock_add_rms_norm)
+def test_RMSNorm_forward_310p(mock_add_rmsnorm, mock_rmsnorm, residual, dummy_tensor, default_vllm_config):
     layer = RMSNorm(hidden_size=8, eps=1e-05)
     if residual is not None:
         out_x, out_residual = layer.forward_oot(dummy_tensor, residual)
-        expected_out_residual = dummy_tensor + residual
-        expected_out_x = expected_out_residual + 1
-        mock_rmsnorm.assert_called_once()
+        expected_out_x = 2 * dummy_tensor
+        expected_out_residual = 2 * residual
+        mock_add_rmsnorm.assert_called_once()
         assert torch.allclose(out_x, expected_out_x)
         assert torch.allclose(out_residual, expected_out_residual)
     else:
         out_x = layer.forward_oot(dummy_tensor, residual)
         expected_out_x = dummy_tensor + 1
         mock_rmsnorm.assert_called_once()
-        assert torch.allclose(out_x, expected_out_x)
\ No newline at end of file
+        assert torch.allclose(out_x, expected_out_x)
diff --git a/vllm_ascend/_310p/ops/layernorm.py b/vllm_ascend/_310p/ops/layernorm.py
index f8220d65..fea8b903 100644
--- a/vllm_ascend/_310p/ops/layernorm.py
+++ b/vllm_ascend/_310p/ops/layernorm.py
@@ -11,13 +11,9 @@ class AscendRMSNorm310(AscendRMSNorm):
         residual: torch.Tensor | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         if residual is not None:
-            if x is None or x.numel() == 0 or x.shape[-1] == 0:
-                x = residual
-            else:
-                x = x + residual
-
-            residual = x
-            x, _ = torch_npu.npu_rms_norm(x, self.weight, self.variance_epsilon)
+            x, _, residual = torch_npu.npu_add_rms_norm(x, residual, self.weight, self.variance_epsilon)
+            if self.bias is not None:
+                x.add_(self.bias)
             return x, residual
 
         x, _ = torch_npu.npu_rms_norm(x, self.weight, self.variance_epsilon)
diff --git a/vllm_ascend/_310p/ops/mm_encoder_attention.py b/vllm_ascend/_310p/ops/mm_encoder_attention.py
deleted file mode 100644
index 7c07a0fd..00000000
--- a/vllm_ascend/_310p/ops/mm_encoder_attention.py
+++ /dev/null
@@ -1,61 +0,0 @@
-#
-# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# This file is a part of the vllm-ascend project.
-#
-
-import torch
-import torch_npu
-
-from vllm_ascend.ops.mm_encoder_attention import AscendMMEncoderAttention
-
-
-class AscendMMEncoderAttention310(AscendMMEncoderAttention):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-
-    def forward_oot(
-        self,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
-        cu_seqlens: torch.Tensor | None = None,
-        max_seqlen: int | None = None,
-        **kwargs,
-    ):
-        bsz, q_len = query.size()[:2]
-        kv_len = key.size(1)
-        query = query.view(bsz * q_len, self.num_heads, self.head_size)
-        key = key.view(bsz * kv_len, self.num_kv_heads, self.head_size)
-        value = value.view(bsz * kv_len, self.num_kv_heads, self.head_size)
-
-        if cu_seqlens is None:
-            seq_len = torch.tensor([q_len] * bsz, device="cpu", dtype=torch.int32)
-        else:
-            seq_len = torch.diff(cu_seqlens.to("cpu", dtype=torch.int32))
-
-        output = torch.empty_like(query)
-        torch_npu._npu_flash_attention_unpad(
-            query=query,
-            key=key,
-            value=value,
-            seq_len=seq_len,
-            scale_value=self.head_size**-0.5,
-            num_heads=self.num_heads,
-            num_kv_heads=self.num_kv_heads,
-            out=output,
-        )
-
-        output = output.view(bsz, -1, self.num_heads, self.head_size)
-        return output
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index 2f150160..844f9e16 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -628,13 +628,11 @@ def register_ascend_customop(vllm_config: VllmConfig | None = None):
         from vllm_ascend._310p.fused_moe.fused_moe import AscendFusedMoE310, AscendSharedFusedMoE310
         from vllm_ascend._310p.ops.activation import AscendSiluAndMul310
         from vllm_ascend._310p.ops.layernorm import AscendGemmaRMSNorm310, AscendRMSNorm310
-        from vllm_ascend._310p.ops.mm_encoder_attention import AscendMMEncoderAttention310
         from vllm_ascend._310p.ops.rotary_embedding import AscendRotaryEmbedding310
 
         REGISTERED_ASCEND_OPS.update(
             {
                 "SiluAndMul": AscendSiluAndMul310,
-                "MMEncoderAttention": AscendMMEncoderAttention310,
                 "RotaryEmbedding": AscendRotaryEmbedding310,
                 "RMSNorm": AscendRMSNorm310,
                 "GemmaRMSNorm": AscendGemmaRMSNorm310,
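
Note for reviewers without 310P hardware: `forward_oot` unpacks `torch_npu.npu_add_rms_norm` as `(x, _, residual)`, where the third element is the residual sum `x + residual` and the discarded middle element is presumably the per-row inverse standard deviation. The sketch below is a plain-PyTorch reference for the semantics the fused call is expected to preserve, not the kernel itself; `reference_add_rms_norm` is a name invented here for illustration.

```python
import torch


def reference_add_rms_norm(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    eps: float = 1e-5,
) -> tuple[torch.Tensor, torch.Tensor]:
    # The two-step path this patch removes: residual add, then RMSNorm.
    added = x + residual
    # RMSNorm over the last dimension, accumulated in fp32 for stability.
    variance = added.float().pow(2).mean(dim=-1, keepdim=True)
    normed = (added.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight
    # Return both tensors, matching the (x, residual) pair forward_oot yields.
    return normed, added
```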
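The mock bodies in `test_layernorm.py` sit outside this diff's context lines, but the new assertions (`out_x == 2 * dummy_tensor`, `out_residual == 2 * residual`) pin down their shape; a side effect along these lines would satisfy them, though the actual helper may differ:

```python
def mock_add_rms_norm(x, residual, weight, eps):
    # The triple mirrors npu_add_rms_norm's return layout, which forward_oot
    # unpacks as (x, _, residual); doubling the inputs keeps the fused path
    # distinguishable from mock_rms_norm's "x + 1" behaviour.
    return 2 * x, None, 2 * residual
```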