diff --git a/tests/ut/_310p/ops/test_layernorm_310.py b/tests/ut/_310p/ops/test_layernorm_310.py
new file mode 100644
index 00000000..92bdf0a7
--- /dev/null
+++ b/tests/ut/_310p/ops/test_layernorm_310.py
@@ -0,0 +1,41 @@
+from unittest.mock import MagicMock, patch
+
+import pytest
+import torch
+from vllm.config import set_current_vllm_config
+from vllm.model_executor.layers.layernorm import RMSNormGated
+
+from vllm_ascend._310p.ops.layernorm import AscendRMSNormGated310
+
+
+@pytest.fixture(autouse=True)
+def default_vllm_config():
+    mock_config = MagicMock()
+    mock_config.compilation_config.custom_ops = ["all"]
+    with set_current_vllm_config(mock_config):
+        yield mock_config
+
+
+def test_rmsnorm_gated_310_forward_oot_uses_forward_native():
+    layer = AscendRMSNormGated310(hidden_size=8, eps=1e-5)
+    x = torch.randn(2, 8, dtype=torch.float32)
+    z = torch.randn(2, 8, dtype=torch.float32)
+    expected = torch.randn(2, 8, dtype=torch.float32)
+
+    with patch.object(RMSNormGated, "forward_native", autospec=True, return_value=expected) as mock_forward_native:
+        out = layer.forward_oot(x, z)
+
+    mock_forward_native.assert_called_once_with(layer, x, z)
+    assert out is expected
+
+
+def test_rmsnorm_gated_310_forward_oot_uses_forward_native_without_gate():
+    layer = AscendRMSNormGated310(hidden_size=8, eps=1e-5)
+    x = torch.randn(2, 8, dtype=torch.float32)
+    expected = torch.randn(2, 8, dtype=torch.float32)
+
+    with patch.object(RMSNormGated, "forward_native", autospec=True, return_value=expected) as mock_forward_native:
+        out = layer.forward_oot(x, None)
+
+    mock_forward_native.assert_called_once_with(layer, x, None)
+    assert out is expected
diff --git a/vllm_ascend/_310p/ops/layernorm.py b/vllm_ascend/_310p/ops/layernorm.py
index fea8b903..ca0e8c04 100644
--- a/vllm_ascend/_310p/ops/layernorm.py
+++ b/vllm_ascend/_310p/ops/layernorm.py
@@ -1,5 +1,6 @@
 import torch
 import torch_npu
+from vllm.model_executor.layers.layernorm import RMSNormGated
 
 from vllm_ascend.ops.layernorm import AscendGemmaRMSNorm, AscendRMSNorm
 
@@ -37,3 +38,14 @@ class AscendGemmaRMSNorm310(AscendGemmaRMSNorm):
         x, _ = torch_npu.npu_rms_norm(x, 1.0 + self.weight,
                                       self.variance_epsilon)
         return x
+
+
+class AscendRMSNormGated310(RMSNormGated):
+    def forward_oot(
+        self,
+        x: torch.Tensor,
+        z: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        # 310P should not depend on the Triton-gated layernorm path.
+        # Reuse the upstream native implementation directly.
+        return super().forward_native(x, z)
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index 841e8562..c1e1db42 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -661,7 +661,11 @@ def register_ascend_customop(vllm_config: VllmConfig | None = None):
     if is_310p():
         from vllm_ascend._310p.fused_moe.fused_moe import AscendFusedMoE310, AscendSharedFusedMoE310
         from vllm_ascend._310p.ops.activation import AscendSiluAndMul310
-        from vllm_ascend._310p.ops.layernorm import AscendGemmaRMSNorm310, AscendRMSNorm310
+        from vllm_ascend._310p.ops.layernorm import (
+            AscendGemmaRMSNorm310,
+            AscendRMSNorm310,
+            AscendRMSNormGated310,
+        )
         from vllm_ascend._310p.ops.mm_encoder_attention import AscendMMEncoderAttention310
         from vllm_ascend._310p.ops.rotary_embedding import AscendRotaryEmbedding310
         from vllm_ascend._310p.ops.vocab_parallel_embedding import (
@@ -675,6 +679,7 @@ def register_ascend_customop(vllm_config: VllmConfig | None = None):
             "RotaryEmbedding": AscendRotaryEmbedding310,
             "RMSNorm": AscendRMSNorm310,
             "GemmaRMSNorm": AscendGemmaRMSNorm310,
+            "RMSNormGated": AscendRMSNormGated310,
             "FusedMoE": AscendFusedMoE310,
             "SharedFusedMoE": AscendSharedFusedMoE310,
             "ParallelLMHead": AscendParallelLMHead310,