diff --git a/tests/e2e/singlecard/test_aclgraph.py b/tests/e2e/singlecard/test_aclgraph.py index 020196d..2a03744 100644 --- a/tests/e2e/singlecard/test_aclgraph.py +++ b/tests/e2e/singlecard/test_aclgraph.py @@ -36,7 +36,7 @@ MODELS = [ @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("max_tokens", [32]) -def test_models( +def test_models_with_aclgraph( model: str, max_tokens: int, ) -> None: @@ -48,12 +48,12 @@ def test_models( sampling_params = SamplingParams(max_tokens=max_tokens, temperature=0.0) # TODO: change to use vllmrunner when the registry of custom op is solved # while running pytest - vllm_model = LLM(model) + vllm_model = LLM(model, max_model_len=1024) vllm_aclgraph_outputs = vllm_model.generate(prompts, sampling_params) del vllm_model torch.npu.empty_cache() - vllm_model = LLM(model, enforce_eager=True) + vllm_model = LLM(model, enforce_eager=True, max_model_len=1024) vllm_eager_outputs = vllm_model.generate(prompts, sampling_params) del vllm_model torch.npu.empty_cache() diff --git a/tests/ut/base.py b/tests/ut/base.py index e34f175..8b396d6 100644 --- a/tests/ut/base.py +++ b/tests/ut/base.py @@ -15,7 +15,7 @@ import unittest -from vllm_ascend.utils import adapt_patch +from vllm_ascend.utils import adapt_patch, register_ascend_customop # fused moe ops test will hit the infer_schema error, we need add the patch # here to make the test pass. @@ -28,4 +28,5 @@ class TestBase(unittest.TestCase): # adapt patch by default. adapt_patch(True) adapt_patch() + register_ascend_customop() super().setUp() diff --git a/tests/ut/conftest.py b/tests/ut/conftest.py new file mode 100644 index 0000000..799edc6 --- /dev/null +++ b/tests/ut/conftest.py @@ -0,0 +1,26 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# + +from vllm_ascend.utils import adapt_patch # noqa: E402 +from vllm_ascend.utils import register_ascend_customop + +adapt_patch() +adapt_patch(True) + +# register Ascend CustomOp here because unit tests will use this +register_ascend_customop() diff --git a/tests/ut/ops/test_activation.py b/tests/ut/ops/test_activation.py new file mode 100644 index 0000000..b90ccff --- /dev/null +++ b/tests/ut/ops/test_activation.py @@ -0,0 +1,61 @@ +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. 
+# + +from unittest.mock import patch + +import pytest +import torch +from vllm.model_executor.layers.activation import QuickGELU, SiluAndMul + + +@pytest.fixture +def dummy_tensor(): + return torch.randn(4, 8, dtype=torch.float16) + + +@patch("torch_npu.npu_fast_gelu", side_effect=lambda x: x + 1) +def test_QuickGELU_forward(mock_gelu, dummy_tensor): + layer = QuickGELU() + out = layer.forward(dummy_tensor) + + expected_out = dummy_tensor + 1 + assert torch.allclose(out, expected_out) + + mock_gelu.assert_called_once() + + +@pytest.mark.parametrize("is_310p_return", [True, False]) +@patch("torch_npu.npu_swiglu", side_effect=lambda x: x + 1) +def test_SiluAndMul_forward(mock_swiglu, is_310p_return, dummy_tensor): + + with patch("vllm_ascend.utils.is_310p", return_value=is_310p_return): + layer = SiluAndMul() + out = layer.forward(dummy_tensor) + + if is_310p_return: + expected_arg = dummy_tensor.to(torch.float32) + else: + expected_arg = dummy_tensor + + # assert mock_swiglu.call_count == 1 + mock_swiglu.assert_called_once() + + actual_arg = mock_swiglu.call_args[0][0] + assert torch.allclose( + actual_arg, + expected_arg), "npu_swiglu called with unexpected input" + + expected_out = dummy_tensor + 1 + assert torch.allclose(out, expected_out) diff --git a/tests/ut/test_utils.py b/tests/ut/test_utils.py index 5ddc59d..3902b6d 100644 --- a/tests/ut/test_utils.py +++ b/tests/ut/test_utils.py @@ -301,6 +301,24 @@ class TestUtils(TestBase): self.assertFalse(utils.check_kv_cache_bytes_cache_exist(), "Delete kv cache bytes cache dir failed") + @mock.patch("vllm.model_executor.custom_op.CustomOp") + @mock.patch("vllm_ascend.ops.activation.AscendQuickGELU") + @mock.patch("vllm_ascend.ops.activation.AscendSiluAndMul") + def test_register_ascend_customop(self, mock_ascend_silu_and_mul, + mock_ascend_quick_gelu, mock_customop): + utils._ASCEND_CUSTOMOP_IS_REIGISTERED = False + + # ascend custom op is not registered + utils.register_ascend_customop() + # should call 
register_oot twice + self.assertEqual(mock_customop.register_oot.call_count, 2) + self.assertTrue(utils._ASCEND_CUSTOMOP_IS_REIGISTERED) + + # ascend custom op is already registered + utils.register_ascend_customop() + # should not register_oot again, thus only called twice in this ut + self.assertEqual(mock_customop.register_oot.call_count, 2) + class TestProfileExecuteDuration(unittest.TestCase): diff --git a/vllm_ascend/ops/activation.py b/vllm_ascend/ops/activation.py index 1c32643..26082fe 100644 --- a/vllm_ascend/ops/activation.py +++ b/vllm_ascend/ops/activation.py @@ -18,25 +18,25 @@ import torch from vllm.model_executor.layers.activation import QuickGELU, SiluAndMul -from vllm_ascend.utils import is_310p + +class AscendQuickGELU(QuickGELU): + + def forward_oot(self, x: torch.tensor) -> torch.Tensor: + import torch_npu + + out = torch_npu.npu_fast_gelu(x) + return out -def silu_and_mul_forward_oot(self, x: torch.Tensor) -> torch.Tensor: - import torch_npu +class AscendSiluAndMul(SiluAndMul): - if is_310p(): - out = torch_npu.npu_swiglu(x.to(torch.float32)).to(torch.float16) - else: - out = torch_npu.npu_swiglu(x) - return out + def forward_oot(self, x: torch.Tensor) -> torch.Tensor: + import torch_npu + from vllm_ascend.utils import is_310p -def quick_gelu_forward_oot(self, x: torch.tensor) -> torch.Tensor: - import torch_npu - - out = torch_npu.npu_fast_gelu(x) - return out - - -QuickGELU.forward_oot = quick_gelu_forward_oot -SiluAndMul.forward_oot = silu_and_mul_forward_oot \ No newline at end of file + if is_310p(): + out = torch_npu.npu_swiglu(x.to(torch.float32)).to(torch.float16) + else: + out = torch_npu.npu_swiglu(x) + return out diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index f13ed49..6ee02cd 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -29,7 +29,7 @@ from vllm.platforms import Platform, PlatformEnum from vllm_ascend.ascend_config import (check_ascend_config, get_ascend_config, init_ascend_config) from 
vllm_ascend.utils import (ASCEND_QUATIZATION_METHOD, is_310p, - update_aclgraph_sizes) + register_ascend_customop, update_aclgraph_sizes) if TYPE_CHECKING: from vllm.config import ModelConfig, VllmConfig @@ -205,6 +205,9 @@ class NPUPlatform(Platform): ascend_config.ascend_scheduler_config) vllm_config.scheduler_config = ascend_scheduler_config + # register Ascend CustomOp + register_ascend_customop() + @classmethod def get_attn_backend_cls(cls, selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1, use_mla): diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index 634e13c..aed5772 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -561,3 +561,26 @@ def delete_torchair_cache_file(): torch_air_abs_path = get_torchair_current_work_dir() if os.path.exists(torch_air_abs_path): shutil.rmtree(torch_air_abs_path) + + +_ASCEND_CUSTOMOP_IS_REIGISTERED = False + + +def register_ascend_customop(): + """Register Ascend CustomOp + + NOTE: if the register branch requires model type, please use `vllm.config.get_current_vllm_config`, + and ensure this will execute after model config is initialized. + """ + global _ASCEND_CUSTOMOP_IS_REIGISTERED + if _ASCEND_CUSTOMOP_IS_REIGISTERED: + return + from vllm.model_executor.custom_op import CustomOp + + from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul + CustomOp.register_oot(_decorated_op_cls=AscendQuickGELU, name="QuickGELU") + CustomOp.register_oot(_decorated_op_cls=AscendSiluAndMul, + name="SiluAndMul") + + # NOTE: Keep this at last to ensure all custom actions are registered + _ASCEND_CUSTOMOP_IS_REIGISTERED = True