diff --git a/tests/ut/ops/test_layernorm.py b/tests/ut/ops/test_layernorm.py
new file mode 100644
index 00000000..c7bc657f
--- /dev/null
+++ b/tests/ut/ops/test_layernorm.py
@@ -0,0 +1,53 @@
+from unittest.mock import patch
+
+import pytest
+import torch
+from vllm.model_executor.layers.layernorm import RMSNorm
+
+
+@pytest.fixture
+def dummy_tensor():
+    return torch.randn(4, 8, dtype=torch.float16)
+
+
+def mock_rms_norm(x, weight, eps):
+    return x + 1, None
+
+
+def mock_add_rms_norm(x, residual, weight, eps):
+    return 2 * x, None, 2 * residual
+
+
+@pytest.mark.parametrize("is_310p_return", [True, False])
+@pytest.mark.parametrize("residual",
+                         [None, torch.randn(4, 8, dtype=torch.float32)])
+@patch("torch_npu.npu_rms_norm", side_effect=mock_rms_norm)
+@patch("torch_npu.npu_add_rms_norm", side_effect=mock_add_rms_norm)
+def test_RMSNorm_forward(mock_add_rmsnorm, mock_rmsnorm, is_310p_return,
+                         residual, dummy_tensor):
+
+    with patch("vllm_ascend.utils.is_310p", return_value=is_310p_return):
+        layer = RMSNorm(hidden_size=32, eps=1e-05)
+        if residual is not None:
+            out_x, out_residual = layer.forward_oot(dummy_tensor, residual)
+
+            if is_310p_return:
+                expected_arg_x = dummy_tensor + residual.to(dummy_tensor.dtype)
+                expected_out_x = expected_arg_x + 1
+                expected_out_residual = expected_arg_x.to(residual.dtype)
+
+                mock_rmsnorm.assert_called_once()
+                assert torch.allclose(out_x, expected_out_x)
+                assert torch.allclose(out_residual, expected_out_residual)
+            else:
+                expected_out_x = 2 * dummy_tensor
+                expected_out_residual = 2 * residual
+                mock_add_rmsnorm.assert_called_once()
+                assert torch.allclose(out_x, expected_out_x)
+                assert torch.allclose(out_residual, expected_out_residual)
+        else:
+            out_x = layer.forward(dummy_tensor, residual)
+            expected_out_x = dummy_tensor + 1
+
+            mock_rmsnorm.assert_called_once()
+            assert torch.allclose(out_x, expected_out_x)
diff --git a/tests/ut/test_utils.py b/tests/ut/test_utils.py
index 94d37a02..d3c4c903 100644
--- a/tests/ut/test_utils.py
+++ b/tests/ut/test_utils.py
@@ -347,20 +347,22 @@ class TestUtils(TestBase):
     @mock.patch("vllm.model_executor.custom_op.CustomOp")
     @mock.patch("vllm_ascend.ops.activation.AscendQuickGELU")
     @mock.patch("vllm_ascend.ops.activation.AscendSiluAndMul")
-    def test_register_ascend_customop(self, mock_ascend_silu_and_mul,
+    @mock.patch("vllm_ascend.ops.layernorm.AscendRMSNorm")
+    def test_register_ascend_customop(self, mock_ascend_rmsnorm,
+                                      mock_ascend_silu_and_mul,
                                       mock_ascend_quick_gelu, mock_customop):
         utils._ASCEND_CUSTOMOP_IS_REIGISTERED = False
 
         # ascend custom op is not registered
         utils.register_ascend_customop()
-        # should call register_oot twice
-        self.assertEqual(mock_customop.register_oot.call_count, 2)
+        # should call register_oot three times
+        self.assertEqual(mock_customop.register_oot.call_count, 3)
         self.assertTrue(utils._ASCEND_CUSTOMOP_IS_REIGISTERED)
 
         # ascend custom op is already registered
         utils.register_ascend_customop()
-        # should not register_oot again, thus only called twice in this ut
-        self.assertEqual(mock_customop.register_oot.call_count, 2)
+        # should not register_oot again, thus only called three times in this ut
+        self.assertEqual(mock_customop.register_oot.call_count, 3)
 
 
 class TestProfileExecuteDuration(TestBase):
diff --git a/vllm_ascend/ops/layernorm.py b/vllm_ascend/ops/layernorm.py
index 7506f87d..4f0b550e 100644
--- a/vllm_ascend/ops/layernorm.py
+++ b/vllm_ascend/ops/layernorm.py
@@ -20,8 +20,6 @@ from typing import Optional, Tuple, Union
 import torch
 from vllm.model_executor.layers.layernorm import RMSNorm
 
-from vllm_ascend.utils import is_310p
-
 
 class AddRMSNormW8A8Quant(RMSNorm):
     # Fuse AddRmsNorm and W8A8 quantization ops together
@@ -60,27 +58,28 @@ class AddRMSNormW8A8Quant(RMSNorm):
         return x
 
 
-def forward_oot(
-    self,
-    x: torch.Tensor,
-    residual: Optional[torch.Tensor] = None,
-) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-    import torch_npu
+class AscendRMSNorm(RMSNorm):
 
-    if residual is not None:
-        if is_310p():
-            orig_dtype = residual.dtype
-            x = x + residual.to(x.dtype)
-            residual = x.to(orig_dtype)
-            x, _ = torch_npu.npu_rms_norm(x, self.weight,
-                                          self.variance_epsilon)
-        else:
-            x, _, residual = torch_npu.npu_add_rms_norm(
-                x, residual, self.weight, self.variance_epsilon)
-        return x, residual
+    def forward_oot(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        import torch_npu
 
-    x, residual = torch_npu.npu_rms_norm(x, self.weight, self.variance_epsilon)
-    return x
+        from vllm_ascend.utils import is_310p
+        if residual is not None:
+            if is_310p():
+                orig_dtype = residual.dtype
+                x = x + residual.to(x.dtype)
+                residual = x.to(orig_dtype)
+                x, _ = torch_npu.npu_rms_norm(x, self.weight,
+                                              self.variance_epsilon)
+            else:
+                x, _, residual = torch_npu.npu_add_rms_norm(
+                    x, residual, self.weight, self.variance_epsilon)
+            return x, residual
 
-
-RMSNorm.forward_oot = forward_oot
+        x, residual = torch_npu.npu_rms_norm(x, self.weight,
+                                             self.variance_epsilon)
+        return x
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index a3befce1..a0586a0f 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -479,6 +479,9 @@ def register_ascend_customop():
     CustomOp.register_oot(_decorated_op_cls=AscendSiluAndMul,
                           name="SiluAndMul")
 
+    from vllm_ascend.ops.layernorm import AscendRMSNorm
+    CustomOp.register_oot(_decorated_op_cls=AscendRMSNorm, name="RMSNorm")
+
     # NOTE: Keep this at last to ensure all custom actions are registered
     _ASCEND_CUSTOMOP_IS_REIGISTERED = True