diff --git a/tests/ut/base.py b/tests/ut/base.py
index 065e68e..36583a5 100644
--- a/tests/ut/base.py
+++ b/tests/ut/base.py
@@ -19,10 +19,6 @@ import pytest
 
 from vllm_ascend.utils import adapt_patch, register_ascend_customop
 
-# fused moe ops test will hit the infer_schema error, we need add the patch
-# here to make the test pass.
-import vllm_ascend.patch.worker.patch_common.patch_utils  # type: ignore[import] # isort: skip # noqa
-
 
 
 class TestBase(unittest.TestCase):
diff --git a/tests/ut/patch/worker/patch_common/test_patch_linear.py b/tests/ut/patch/worker/patch_common/test_patch_linear.py
new file mode 100644
index 0000000..53be010
--- /dev/null
+++ b/tests/ut/patch/worker/patch_common/test_patch_linear.py
@@ -0,0 +1,167 @@
+from importlib import reload
+
+import pytest
+import torch
+import vllm
+from pytest_mock import MockerFixture
+
+from tests.ut.base import PytestBase
+from vllm_ascend import envs
+from vllm_ascend.patch.worker.patch_common import patch_linear
+
+
+class TestAscendRowParallelLinear(PytestBase):
+
+    def init_row_parallel_linear(self, mocker: MockerFixture):
+        mocker.patch(
+            "vllm_ascend.patch.worker.patch_common.patch_linear.AscendRowParallelLinear.__init__",
+            return_value=None,
+        )
+        mocker.patch("torch.nn.Module.__setattr__")
+        mocker.patch("torch.nn.Module.__getattr__")
+        mocker.patch("torch.nn.Module.__delattr__")
+        return patch_linear.AscendRowParallelLinear(
+            input_size=128,
+            output_size=256,
+        )
+
+    @pytest.mark.parametrize(
+        "version, expected",
+        [
+            ("1.0.0", 1),
+            ("2.1.0", 1),
+        ],
+    )
+    def test_get_hcomm_info(self, version, expected, mocker: MockerFixture):
+        mock_group = mocker.MagicMock()
+        backend = mocker.MagicMock()
+        backend.get_hccl_comm_name = lambda x: x
+        mock_group._get_backend = lambda x: backend
+        mock_group.get_hccl_comm_name = lambda x: x
+        mocker.patch("torch.distributed.get_rank", return_value=1)
+        mocker.patch(
+            "torch.distributed.get_global_rank",
+            return_value=0,
+        )
+        mocker.patch("torch.__version__", new=version)
+        hcomm_info = patch_linear.AscendRowParallelLinear.get_hcomm_info(
+            mock_group)
+        assert hcomm_info == expected
+
+    @pytest.mark.parametrize(
+        "skip_bias_add, return_bias, bias, expected",
+        [
+            (True, False, torch.tensor(1.0), torch.tensor(14.0)),
+            (False, True, torch.tensor(1.0), (torch.tensor(14.0), None)),
+            (
+                True,
+                True,
+                torch.tensor(1.0),
+                (torch.tensor(14.0), torch.tensor(1.0)),
+            ),
+        ],
+    )
+    def test_forward(
+        self,
+        skip_bias_add,
+        return_bias,
+        bias,
+        expected,
+        mocker: MockerFixture,
+    ):
+        mocker_tp_group = mocker.MagicMock()
+        mocker_tp_group.device_group = mocker.MagicMock()
+        row_parallel_linear = self.init_row_parallel_linear(mocker)
+        row_parallel_linear.__dict__["tp_rank"] = 0
+        row_parallel_linear.__dict__["skip_bias_add"] = skip_bias_add
+        row_parallel_linear.__dict__["return_bias"] = return_bias
+        row_parallel_linear.__dict__["bias"] = bias
+        row_parallel_linear.__dict__["quant_method"] = mocker.MagicMock()
+        row_parallel_linear.__dict__["calc_input"] = lambda x: x  # noqa
+        row_parallel_linear.__dict__[
+            "calc_output"] = lambda x: x.matmul(  # noqa
+                torch.tensor([1.0, 2.0]))
+        ret = row_parallel_linear.forward(torch.tensor([10.0, 2.0]))
+        if isinstance(ret, tuple):
+            assert torch.allclose(ret[0], expected[0])
+            if ret[1] is None:
+                assert ret[1] == expected[1]
+            else:
+                assert torch.allclose(ret[1], expected[1])
+        else:
+            assert torch.allclose(ret, expected)
+
+    @pytest.mark.parametrize(
+        "input_is_parallel, expected",
+        [
+            (True, torch.tensor([10.0, 2.0])),
+            (False, torch.tensor([10.0])),
+        ],
+    )
+    def test_calc_input(
+        self,
+        input_is_parallel,
+        expected,
+        mocker: MockerFixture,
+    ):
+        row_parallel_linear = self.init_row_parallel_linear(mocker)
+        row_parallel_linear.__dict__["input_is_parallel"] = input_is_parallel
+        input_tensor = torch.Tensor([10, 2])
+        mocker.patch(
+            "vllm_ascend.patch.worker.patch_common.patch_linear.get_tensor_model_parallel_rank",  # noqa
+            return_value=0,
+        )
+        mocker.patch(
+            "vllm_ascend.patch.worker.patch_common.patch_linear.split_tensor_along_last_dim",  # noqa
+            return_value=[torch.Tensor([10]),
+                          torch.Tensor([2])],
+        )
+        input_parallel = row_parallel_linear.calc_input(input_tensor)
+        assert torch.allclose(input_parallel, expected)
+
+    @pytest.mark.parametrize(
+        "reduce_results, tp_size, expected",
+        [
+            (True, 2, torch.tensor(56.0)),
+            (True, 1, torch.tensor(14.0)),
+            (False, 2, torch.tensor(14.0)),
+        ],
+    )
+    def test_calc_output(
+        self,
+        reduce_results,
+        tp_size,
+        expected,
+        mocker: MockerFixture,
+    ):
+        quant_method = mocker.MagicMock()
+        quant_method.apply = lambda self, x, bias=None: x.matmul(  # noqa
+            torch.tensor([1.0, 2.0]))
+        row_parallel_linear = self.init_row_parallel_linear(mocker)
+        row_parallel_linear.__dict__["reduce_results"] = reduce_results
+        row_parallel_linear.__dict__["tp_size"] = tp_size
+        row_parallel_linear.__dict__["quant_method"] = quant_method
+        row_parallel_linear.__dict__["tp_rank"] = 0
+        row_parallel_linear.__dict__["get_hcomm_info"] = lambda x: None  # noqa
+
+        mocker.patch(
+            "vllm_ascend.patch.worker.patch_common.patch_linear.get_tp_group",
+            return_value=mocker.MagicMock(device_group=mocker.MagicMock()),
+        )
+        mocker.patch(
+            "torch_npu.npu_mm_all_reduce_base",
+            side_effect=lambda input_, weight, hccl_info, bias: input_.matmul(  # noqa
+                torch.tensor([4.0, 8.0])),
+        )
+        ret = row_parallel_linear.calc_output(torch.tensor([10.0, 2.0]))
+        assert torch.allclose(ret, expected)
+
+    def test_enable_allreduce_matmul(self, mocker: MockerFixture):
+        mocker.patch.object(envs,
+                            "VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE",
+                            new=True)
+        reload(patch_linear)
+        assert envs.VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE
+        assert id(vllm.model_executor.layers.linear.RowParallelLinear) == id(
+            patch_linear.AscendRowParallelLinear)
diff --git a/vllm_ascend/envs.py b/vllm_ascend/envs.py
index 3409de9..eb0223c 100644
--- a/vllm_ascend/envs.py
+++ b/vllm_ascend/envs.py
@@ -154,7 +154,11 @@ env_variables: Dict[str, Callable[[], Any]] = {
     # Whether to enable mla_pa for deepseek mla decode, this flag will be removed after its available torch_npu is public accessible
     # and the mla_pa will be the default path of deepseek decode path.
     "VLLM_ASCEND_MLA_PA":
-    lambda: int(os.getenv("VLLM_ASCEND_MLA_PA", 0))
+    lambda: int(os.getenv("VLLM_ASCEND_MLA_PA", 0)),
+    # Whether to enable the MatmulAllReduce fusion kernel when tensor parallelism is enabled.
+    # This feature is supported on A2 hardware, and eager mode gets better performance from it.
+    "VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE":
+    lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE", '0'))),
 }
 
 # end-env-vars-definition
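Note: the flag above is evaluated when accessed through `vllm_ascend.envs`, so it should be set in the process environment before the patch modules are imported. A minimal sketch of enabling it (the launch command shown in the comment is illustrative, not part of this change):

# Sketch, not part of the diff: enable the fusion before vllm_ascend loads.
# Equivalent shell form (hypothetical command line):
#   VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE=1 python serve.py
import os

os.environ["VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE"] = "1"  # before importing vllm_ascend

from vllm_ascend import envs

assert envs.VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE  # bool(int("1")) -> True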
diff --git a/vllm_ascend/patch/__init__.py b/vllm_ascend/patch/__init__.py
index 61d92ea..3446c45 100644
--- a/vllm_ascend/patch/__init__.py
+++ b/vllm_ascend/patch/__init__.py
@@ -114,3 +114,19 @@
 #     - https://github.com/vllm-project/vllm/pull/21591
 #   Future Plan:
 #     Revert it when vLLM merge #21591 and release new version
+# ** File: worker/patch_common/patch_linear.py **
+#    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#   1. `vllm.model_executor.layers.linear.RowParallelLinear`
+#   Why:
+#     We need to fuse matmul and allreduce in `RowParallelLinear`
+#     to improve performance.
+#   How:
+#     Create a new class `AscendRowParallelLinear` that inherits from
+#     `RowParallelLinear`, and override its `forward` method to call
+#     `torch_npu.npu_mm_all_reduce_base` in place of the separate matmul and allreduce.
+#   Related PR (if no, explain why):
+#     - https://github.com/vllm-project/vllm-ascend/pull/1926
+#   Future Plan:
+#     Validate more models in all kinds of scenarios; if performance is
+#     consistently improved, we can enable this patch by default and remove
+#     the env variable `VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE` in the future.
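To make the "How" concrete, here is a minimal sketch contrasting the two paths on each tensor-parallel rank (tensor shapes, the process-group setup, and the helper names `unfused`/`fused` are illustrative assumptions; the fused call mirrors the one in patch_linear.py later in this diff):

# Sketch, not part of the diff. The default RowParallelLinear computes a
# partial matmul and then runs a separate all-reduce; the fused kernel
# performs both in a single launch.
import torch
import torch.distributed as dist
import torch_npu


def unfused(input_parallel: torch.Tensor, weight_t: torch.Tensor,
            group: dist.ProcessGroup) -> torch.Tensor:
    output = input_parallel.matmul(weight_t)  # partial result on this rank
    dist.all_reduce(output, group=group)      # separate communication kernel
    return output


def fused(input_parallel: torch.Tensor, weight_t: torch.Tensor,
          hcomm_info: str) -> torch.Tensor:
    # One fused kernel: matmul + HCCL all-reduce, same arguments as the
    # call in AscendRowParallelLinear.calc_output.
    return torch_npu.npu_mm_all_reduce_base(input_parallel, weight_t,
                                            hcomm_info, bias=None)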
diff --git a/vllm_ascend/patch/worker/patch_common/__init__.py b/vllm_ascend/patch/worker/patch_common/__init__.py
index c935868..8eebcdf 100644
--- a/vllm_ascend/patch/worker/patch_common/__init__.py
+++ b/vllm_ascend/patch/worker/patch_common/__init__.py
@@ -19,5 +19,6 @@
 # patch files.
 import vllm_ascend.patch.worker.patch_common.patch_utils  # noqa isort:skip
 import vllm_ascend.patch.worker.patch_common.patch_distributed  # noqa
+import vllm_ascend.patch.worker.patch_common.patch_linear  # noqa
 import vllm_ascend.patch.worker.patch_common.patch_minicpm  # noqa
 import vllm_ascend.patch.worker.patch_common.patch_sampler  # noqa
diff --git a/vllm_ascend/patch/worker/patch_common/patch_linear.py b/vllm_ascend/patch/worker/patch_common/patch_linear.py
new file mode 100644
index 0000000..f5fbcec
--- /dev/null
+++ b/vllm_ascend/patch/worker/patch_common/patch_linear.py
@@ -0,0 +1,145 @@
+"""
+Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+This file is a part of the vllm-ascend project.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+from typing import Optional, Union
+
+import torch
+import torch_npu
+import vllm
+from torch.distributed import ProcessGroup
+from torch.nn.parameter import Parameter
+from vllm.distributed import (get_tensor_model_parallel_rank,
+                              split_tensor_along_last_dim)
+from vllm.distributed.parallel_state import get_tp_group
+from vllm.model_executor.layers.linear import RowParallelLinear
+
+from vllm_ascend import envs
+
+_HCOMM_INFO = None
+
+
+class AscendRowParallelLinear(RowParallelLinear):
+    """
+    AscendRowParallelLinear is a custom implementation of RowParallelLinear
+    that overrides the forward method to handle Ascend-specific operations.
+    """
+
+    def __init__(self, *args, **kwargs):
+        """Initialize the AscendRowParallelLinear layer.
+
+        Args:
+            *args: Variable length argument list.
+            **kwargs: Arbitrary keyword arguments.
+        """
+        tp_group = get_tp_group().device_group
+        hcomm_info = self.get_hcomm_info(tp_group)
+        self.hcomm_info = hcomm_info
+        super().__init__(*args, **kwargs)
+        self.weight_t = self.weight.t()
+
+    @staticmethod
+    def get_hcomm_info(group: ProcessGroup) -> str:
+        """Get the HCCL communication information for the given group.
+
+        Args:
+            group (ProcessGroup): The process group for which to get the HCCL communication info.
+
+        Returns:
+            str: The HCCL communication name for the given group.
+        """
+        global _HCOMM_INFO
+        if _HCOMM_INFO is not None:
+            return _HCOMM_INFO
+
+        rank = torch.distributed.get_rank(group)
+        if torch.__version__ > "2.0":
+            global_rank = torch.distributed.get_global_rank(group, rank)
+            _HCOMM_INFO = group._get_backend(
+                torch.device("npu")).get_hccl_comm_name(global_rank)
+        else:
+            _HCOMM_INFO = group.get_hccl_comm_name(rank)
+        return _HCOMM_INFO
+
+    def forward(
+        self, input_: torch.Tensor
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
+        """Forward pass for the AscendRowParallelLinear layer.
+
+        Args:
+            input_ (torch.Tensor): the input tensor to the layer.
+
+        Returns:
+            Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
+                The output tensor after applying the linear transformation,
+                and optionally the bias if `return_bias` is True.
+        """
+        input_parallel = self.calc_input(input_)
+
+        # Matrix multiply.
+        assert self.quant_method is not None
+        # Only fuse bias add into GEMM for rank 0 (this ensures that
+        # bias will not get added more than once in TP>1 case)
+        output = self.calc_output(input_parallel)
+
+        output_bias = self.bias if self.skip_bias_add else None
+
+        if not self.return_bias:
+            return output
+        return output, output_bias
+
+    def calc_input(self, input_: torch.Tensor) -> torch.Tensor:
+        """Calculate the input tensor for parallel processing.
+
+        Args:
+            input_ (torch.Tensor): the input tensor to be processed.
+
+        Returns:
+            torch.Tensor: The input tensor split along the last dimension
+                for tensor model parallelism, or the original input if not parallel.
+        """
+        if self.input_is_parallel:
+            return input_
+        tp_rank = get_tensor_model_parallel_rank()
+        splitted_input = split_tensor_along_last_dim(
+            input_, num_partitions=self.tp_size)
+        return splitted_input[tp_rank].contiguous()
+
+    def calc_output(self, input_parallel: torch.Tensor) -> torch.Tensor:
+        """Calculate the output tensor of the forward pass, fusing
+        communication and computation where possible.
+
+        Args:
+            input_parallel (torch.Tensor): the input tensor to be processed in parallel.
+
+        Returns:
+            torch.Tensor: the output tensor after applying the linear transformation,
+                optionally handling communication between tensor model parallel ranks.
+        """
+        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
+        if self.reduce_results and self.tp_size > 1:
+            output = torch_npu.npu_mm_all_reduce_base(input_parallel,
+                                                      self.weight_t,
+                                                      self.hcomm_info,
+                                                      bias=bias_)
+        else:
+            output = self.quant_method.apply(self, input_parallel, bias=bias_)
+        return output
+
+
+if envs.VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE:
+    vllm.model_executor.layers.linear.RowParallelLinear = AscendRowParallelLinear
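Once imported with the flag enabled, the module-level assignment above rebinds the class for every later lookup. A short sketch of the observable effect, mirroring what test_enable_allreduce_matmul asserts (assumes VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE=1 was set before import):

# Sketch, not part of the diff: after the patched import, any model that
# looks up vllm.model_executor.layers.linear.RowParallelLinear gets the
# Ascend variant transparently, with no changes in model code.
import vllm
from vllm_ascend.patch.worker.patch_common import patch_linear

assert (vllm.model_executor.layers.linear.RowParallelLinear
        is patch_linear.AscendRowParallelLinear)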