[Feature]: implement the fusion of allreduce and matmul in prefill phase when tp is enabled (#1926)

### What this PR does / why we need it?
it'll execute allreduce and malmul seperately in vllm RowParallelLinear
forward funcion, this function use torch_npu.npu_mm_all_reduce_base to
execute allreduce and matmul in a fused kernel way. this will gain a 20%
performance
promotion in eager mode.
### Does this PR introduce _any_ user-facing change?
this PR introduce a new env `VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE` to
control whether enable the feature or not.

### How was this patch tested?
the patch is tested by adding a new test file `test_patch_linear.py` to
guard the ut


- vLLM version: v0.10.0
- vLLM main:
7728dd77bb

Signed-off-by: Ronald1995 <ronaldautomobile@163.com>
This commit is contained in:
Ronald1995
2025-07-28 15:13:37 +08:00
committed by GitHub
parent ba3dfbd59e
commit 32a9c5f694
6 changed files with 334 additions and 5 deletions

View File

@@ -19,10 +19,6 @@ import pytest
from vllm_ascend.utils import adapt_patch, register_ascend_customop
# fused moe ops test will hit the infer_schema error, we need add the patch
# here to make the test pass.
import vllm_ascend.patch.worker.patch_common.patch_utils # type: ignore[import] # isort: skip # noqa
class TestBase(unittest.TestCase):

View File

@@ -0,0 +1,167 @@
from importlib import reload
import pytest
import torch
import vllm
from pytest_mock import MockerFixture
from tests.ut.base import PytestBase
from vllm_ascend import envs
from vllm_ascend.patch.worker.patch_common import patch_linear
class TestAscendRowParallelLinear(PytestBase):
def init_row_parallel_linear(self, mocker: MockerFixture):
mocker.patch(
"vllm_ascend.patch.worker.patch_common.patch_linear.AscendRowParallelLinear.__init__",
return_value=None,
)
mocker.patch("torch.nn.Module.__setattr__")
mocker.patch("torch.nn.Module.__getattr__")
mocker.patch("torch.nn.Module.__delattr__")
return patch_linear.AscendRowParallelLinear(
input_size=128,
output_size=256,
)
@pytest.mark.parametrize(
"version, expected",
[
("1.0.0", 1),
("2.1.0", 1),
],
)
def test_get_hcomm_info(self, version, expected, mocker: MockerFixture):
mock_group = mocker.MagicMock()
backend = mocker.MagicMock()
backend.get_hccl_comm_name = lambda x: x
mock_group._get_backend = lambda x: backend
mock_group.get_hccl_comm_name = lambda x: x
mocker.patch("torch.distributed.get_rank", return_value=1)
mocker.patch(
"torch.distributed.get_global_rank",
return_value=0,
)
mocker.patch("torch.__version__", new=version)
hcomm_info = patch_linear.AscendRowParallelLinear.get_hcomm_info(
mock_group)
assert hcomm_info == expected
@pytest.mark.parametrize(
"skip_bias_add, return_bias, bias, expected",
[
(True, False, torch.tensor(1.0), torch.tensor(14.0)),
(False, True, torch.tensor(1.0), (torch.tensor(14.0), None)),
(
True,
True,
torch.tensor(1.0),
(torch.tensor(14.0), torch.tensor(1.0)),
),
],
)
def test_forward(
self,
skip_bias_add,
return_bias,
bias,
expected,
mocker: MockerFixture,
):
mocker_tp_group = mocker.MagicMock()
mocker_tp_group.device_group = mocker.MagicMock()
row_parallel_linear = self.init_row_parallel_linear(mocker)
row_parallel_linear.__dict__["tp_rank"] = 0
row_parallel_linear.__dict__["skip_bias_add"] = skip_bias_add
row_parallel_linear.__dict__["return_bias"] = return_bias
row_parallel_linear.__dict__["bias"] = bias
row_parallel_linear.__dict__["qyuant_method"] = mocker.MagicMock()
row_parallel_linear.__dict__["calc_input"] = lambda x: x # noqa
row_parallel_linear.__dict__[
"calc_output"] = lambda x: x.matmul( # noqa
torch.tensor([1.0, 2.0]))
ret = row_parallel_linear.forward(torch.tensor([10.0, 2.0]))
if isinstance(ret, tuple):
assert torch.allclose(ret[0], expected[0])
if ret[1] is None:
assert ret[1] == expected[1]
else:
assert torch.allclose(ret[1], expected[1])
else:
assert torch.allclose(ret, expected)
@pytest.mark.parametrize(
"input_is_parallel, expected",
[
(True, torch.tensor([10.0, 2.0])),
(False, torch.tensor([10.0])),
],
)
def test_calc_input(
self,
input_is_parallel,
expected,
mocker: MockerFixture,
):
row_parallel_linear = self.init_row_parallel_linear(mocker)
row_parallel_linear.__dict__["input_is_parallel"] = input_is_parallel
input_tensor = torch.Tensor([10, 2])
mocker.patch(
"vllm_ascend.patch.worker.patch_common.patch_linear.get_tensor_model_parallel_rank", # noqa
return_value=0,
)
mocker.patch(
"vllm_ascend.patch.worker.patch_common.patch_linear.split_tensor_along_last_dim", # noqa
return_value=[torch.Tensor([10]),
torch.Tensor([2])],
)
input_parallel = row_parallel_linear.calc_input(input_tensor)
assert torch.allclose(input_parallel, expected)
@pytest.mark.parametrize(
"reduce_results, tp_size, expected",
[
(True, 2, torch.tensor(56.0)),
(True, 1, torch.tensor(14.0)),
(False, 2, torch.tensor(14.0)),
],
)
def test_calc_output(
self,
reduce_results,
tp_size,
expected,
mocker: MockerFixture,
):
quant_method = mocker.MagicMock()
quant_method.apply = lambda self, x, bias=None: x.matmul( # noqa
torch.tensor([1.0, 2.0]))
row_parallel_linear = self.init_row_parallel_linear(mocker)
row_parallel_linear.__dict__["reduce_results"] = reduce_results
row_parallel_linear.__dict__["tp_size"] = tp_size
row_parallel_linear.__dict__["quant_method"] = quant_method
row_parallel_linear.__dict__["tp_rank"] = 0
row_parallel_linear.__dict__["get_hcomm_info"] = lambda x: None # noqa
mocker.patch(
"vllm_ascend.patch.worker.patch_common.patch_linear.get_tp_group",
return_value=mocker.MagicMock(device_group=mocker.MagicMock()),
)
mocker.patch(
"torch_npu.npu_mm_all_reduce_base",
side_effect=lambda input_, weight, hccl_info, bias: input_.
matmul( # noqa
torch.tensor([4.0, 8.0])),
) # noqa
ret = row_parallel_linear.calc_output(torch.tensor([10.0, 2.0]))
assert torch.allclose(ret, expected)
def test_enable_allreduce_matmul(self, mocker: MockerFixture):
mocker.patch.object(envs,
"VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE",
new=True)
reload(patch_linear)
assert envs.VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE
assert id(vllm.model_executor.layers.linear.RowParallelLinear) == id(
patch_linear.AscendRowParallelLinear)

View File

@@ -154,7 +154,11 @@ env_variables: Dict[str, Callable[[], Any]] = {
# Whether to enable mla_pa for deepseek mla decode, this flag will be removed after its available torch_npu is public accessible
# and the mla_pa will be the default path of deepseek decode path.
"VLLM_ASCEND_MLA_PA":
lambda: int(os.getenv("VLLM_ASCEND_MLA_PA", 0))
lambda: int(os.getenv("VLLM_ASCEND_MLA_PA", 0)),
# Whether to enable MatmulAllReduce fusion kernel when tensor parallel is enabled.
# this feature is supported in A2, and eager mode will get better performance.
"VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE":
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE", '0'))),
}
# end-env-vars-definition

View File

@@ -114,3 +114,19 @@
# - https://github.com/vllm-project/vllm/pull/21591
# Future Plan:
# Revert it when vLLM merge #21591 and release new version
# ** File: worker/patch_common/patch_linear.py **
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# 1. `vllm.model_executor.layers.linear.RowParallelLinear`
# Why:
# We need to fuse matmul and allreuce in `RowParallelLinear`
# to improve performance.
# How
# Create a new class `AscendRowParallelLinear` that inherits from `RowParallelLinear`.
# In this class, we override the `forward` method to use
# torch_npu.npu_mm_all_reduce_base to replace matmul and allreduce.
# Related PR (if no, explain why):
# - https://github.com/vllm-project/vllm-ascend/pull/1926
# Future Plan:
# Validate more models in all kinds of scenario,
# if performance is always improved, we can enable this patch by default and remove the env
# variable `VLLM_ASCEND_ENABLE_FUSE_MATMUL_ALLREDUCE` in the future.

View File

@@ -19,5 +19,6 @@
# patch files.
import vllm_ascend.patch.worker.patch_common.patch_utils # noqa isort:skip
import vllm_ascend.patch.worker.patch_common.patch_distributed # noqa
import vllm_ascend.patch.worker.patch_common.patch_linear # noqa
import vllm_ascend.patch.worker.patch_common.patch_minicpm # noqa
import vllm_ascend.patch.worker.patch_common.patch_sampler # noqa

View File

@@ -0,0 +1,145 @@
"""
Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
This file is a part of the vllm-ascend project.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from typing import Optional, Union
import torch
import torch_npu
import vllm
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
from vllm.distributed import (get_tensor_model_parallel_rank,
split_tensor_along_last_dim)
from vllm.distributed.parallel_state import get_tp_group
from vllm.model_executor.layers.linear import RowParallelLinear
from vllm_ascend import envs
_HCOMM_INFO = None
class AscendRowParallelLinear(RowParallelLinear):
"""
AscendRowParallelLinear is a custom implementation of RowParallelLinear
that overrides the forward method to handle Ascend-specific operations.
"""
def __init__(self, *args, **kwargs):
"""Initialize the AscendRowParallelLinear layer.
Args:
*args: Variable length argument list.
**kwargs: Arbitrary keyword arguments.
"""
tp_group = get_tp_group().device_group
hcomm_info = self.get_hcomm_info(tp_group)
self.hcomm_info = hcomm_info
super().__init__(*args, **kwargs)
self.weight_t = self.weight.t()
@staticmethod
def get_hcomm_info(group: ProcessGroup) -> str:
"""Get the HCCL communication information for the given group.
Args:
group (ProcessGroup): The process group for which to get the HCCL communication info.
Returns:
str: The HCCL communication name for the given group.
"""
global _HCOMM_INFO
if _HCOMM_INFO is not None:
return _HCOMM_INFO
rank = torch.distributed.get_rank(group)
if torch.__version__ > "2.0":
global_rank = torch.distributed.get_global_rank(group, rank)
_HCOMM_INFO = group._get_backend(
torch.device("npu")).get_hccl_comm_name(global_rank)
else:
_HCOMM_INFO = group.get_hccl_comm_name(rank)
return _HCOMM_INFO
def forward(
self, input_: torch.Tensor
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
"""Forward pass for the AscendRowParallelLinear layer.
Args:
input_ (torch.Tensor): the input tensor to the layer.
Returns:
Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
The output tensor after applying the linear transformation,
and optionally the bias if `return_bias` is True.
"""
input_parallel = self.calc_input(input_)
# Matrix multiply.
assert self.quant_method is not None
# Only fuse bias add into GEMM for rank 0 (this ensures that
# bias will not get added more than once in TP>1 case)
output = self.calc_output(input_parallel)
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias
def calc_input(self, input_: torch.Tensor) -> torch.Tensor:
"""Calculate the input tensor for parallel processing.
Args:
input_ (torch.Tensor): the input tensor to be processed.
Returns:
torch.Tensor: The input tensor split along the last dimension
for tensor model parallelism, or the original input if not parallel.
"""
if self.input_is_parallel:
return input_
tp_rank = get_tensor_model_parallel_rank()
splitted_input = split_tensor_along_last_dim(
input_, num_partitions=self.tp_size)
return splitted_input[tp_rank].contiguous()
def calc_output(self, input_parallel: torch.Tensor) -> torch.Tensor:
"""Calculate the output tensor of forward by considering
fusing communication and computation.
Args:
input_parallel (_type_): the input tensor to be processed in parallel.
Returns:
torch.Tensor: the output tensor after applying the linear transformation
and optionally handle communication between tensor model parallel ranks.
"""
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
if self.reduce_results and self.tp_size > 1:
output = torch_npu.npu_mm_all_reduce_base(input_parallel,
self.weight_t,
self.hcomm_info,
bias=bias_)
else:
output = self.quant_method.apply(self, input_parallel, bias=bias_)
return output
if envs.VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE:
vllm.model_executor.layers.linear.RowParallelLinear = AscendRowParallelLinear