From 83bd77c9833110bca8b7e37e63e92de480200059 Mon Sep 17 00:00:00 2001 From: Shaoxu Cheng <2906339855@qq.com> Date: Tue, 24 Mar 2026 09:00:11 +0800 Subject: [PATCH] [310p]: add rmsnorm gated fallback and unit test (#7424) ### What this PR does / why we need it? RFC #7394 310P cannot use the fused `rmsnormgated` operator and must fall back to the native implementation. ### Does this PR introduce _any_ user-facing change? NO ### How was this patch tested? ut - vLLM version: v0.17.0 - vLLM main: https://github.com/vllm-project/vllm/commit/4497431df654e46fb1fb5e64bf8611e762ae5d87 --------- Signed-off-by: Tflowers-0129 <2906339855@qq.com> --- tests/ut/_310p/ops/test_layernorm_310.py | 41 ++++++++++++++++++++++++ vllm_ascend/_310p/ops/layernorm.py | 12 +++++++ vllm_ascend/utils.py | 7 +++- 3 files changed, 59 insertions(+), 1 deletion(-) create mode 100644 tests/ut/_310p/ops/test_layernorm_310.py diff --git a/tests/ut/_310p/ops/test_layernorm_310.py b/tests/ut/_310p/ops/test_layernorm_310.py new file mode 100644 index 00000000..92bdf0a7 --- /dev/null +++ b/tests/ut/_310p/ops/test_layernorm_310.py @@ -0,0 +1,41 @@ +from unittest.mock import MagicMock, patch + +import pytest +import torch +from vllm.config import set_current_vllm_config +from vllm.model_executor.layers.layernorm import RMSNormGated + +from vllm_ascend._310p.ops.layernorm import AscendRMSNormGated310 + + +@pytest.fixture(autouse=True) +def default_vllm_config(): + mock_config = MagicMock() + mock_config.compilation_config.custom_ops = ["all"] + with set_current_vllm_config(mock_config): + yield mock_config + + +def test_rmsnorm_gated_310_forward_oot_uses_forward_native(): + layer = AscendRMSNormGated310(hidden_size=8, eps=1e-5) + x = torch.randn(2, 8, dtype=torch.float32) + z = torch.randn(2, 8, dtype=torch.float32) + expected = torch.randn(2, 8, dtype=torch.float32) + + with patch.object(RMSNormGated, "forward_native", autospec=True, return_value=expected) as mock_forward_native: + out = layer.forward_oot(x, z) + + mock_forward_native.assert_called_once_with(layer, x, z) + assert out is expected + + +def test_rmsnorm_gated_310_forward_oot_uses_forward_native_without_gate(): + layer = AscendRMSNormGated310(hidden_size=8, eps=1e-5) + x = torch.randn(2, 8, dtype=torch.float32) + expected = torch.randn(2, 8, dtype=torch.float32) + + with patch.object(RMSNormGated, "forward_native", autospec=True, return_value=expected) as mock_forward_native: + out = layer.forward_oot(x, None) + + mock_forward_native.assert_called_once_with(layer, x, None) + assert out is expected diff --git a/vllm_ascend/_310p/ops/layernorm.py b/vllm_ascend/_310p/ops/layernorm.py index fea8b903..ca0e8c04 100644 --- a/vllm_ascend/_310p/ops/layernorm.py +++ b/vllm_ascend/_310p/ops/layernorm.py @@ -1,5 +1,6 @@ import torch import torch_npu +from vllm.model_executor.layers.layernorm import RMSNormGated from vllm_ascend.ops.layernorm import AscendGemmaRMSNorm, AscendRMSNorm @@ -37,3 +38,14 @@ class AscendGemmaRMSNorm310(AscendGemmaRMSNorm): x, _ = torch_npu.npu_rms_norm(x, 1.0 + self.weight, self.variance_epsilon) return x + + +class AscendRMSNormGated310(RMSNormGated): + def forward_oot( + self, + x: torch.Tensor, + z: torch.Tensor | None = None, + ) -> torch.Tensor: + # 310P should not depend on the Triton-gated layernorm path. + # Reuse the upstream native implementation directly. + return super().forward_native(x, z) diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index 841e8562..c1e1db42 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -661,7 +661,11 @@ def register_ascend_customop(vllm_config: VllmConfig | None = None): if is_310p(): from vllm_ascend._310p.fused_moe.fused_moe import AscendFusedMoE310, AscendSharedFusedMoE310 from vllm_ascend._310p.ops.activation import AscendSiluAndMul310 - from vllm_ascend._310p.ops.layernorm import AscendGemmaRMSNorm310, AscendRMSNorm310 + from vllm_ascend._310p.ops.layernorm import ( + AscendGemmaRMSNorm310, + AscendRMSNorm310, + AscendRMSNormGated310, + ) from vllm_ascend._310p.ops.mm_encoder_attention import AscendMMEncoderAttention310 from vllm_ascend._310p.ops.rotary_embedding import AscendRotaryEmbedding310 from vllm_ascend._310p.ops.vocab_parallel_embedding import ( @@ -675,6 +679,7 @@ def register_ascend_customop(vllm_config: VllmConfig | None = None): "RotaryEmbedding": AscendRotaryEmbedding310, "RMSNorm": AscendRMSNorm310, "GemmaRMSNorm": AscendGemmaRMSNorm310, + "RMSNormGated": AscendRMSNormGated310, "FusedMoE": AscendFusedMoE310, "SharedFusedMoE": AscendSharedFusedMoE310, "ParallelLMHead": AscendParallelLMHead310,