[Feature]: implement the fusion of allreduce and matmul in prefill phase when tp is enabled (#1926)
### What this PR does / why we need it?
it'll execute allreduce and malmul seperately in vllm RowParallelLinear
forward funcion, this function use torch_npu.npu_mm_all_reduce_base to
execute allreduce and matmul in a fused kernel way. this will gain a 20%
performance
promotion in eager mode.
### Does this PR introduce _any_ user-facing change?
this PR introduce a new env `VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE` to
control whether enable the feature or not.
### How was this patch tested?
the patch is tested by adding a new test file `test_patch_linear.py` to
guard the ut
- vLLM version: v0.10.0
- vLLM main:
7728dd77bb
Signed-off-by: Ronald1995 <ronaldautomobile@163.com>
This commit is contained in:
@@ -19,10 +19,6 @@ import pytest
|
||||
|
||||
from vllm_ascend.utils import adapt_patch, register_ascend_customop
|
||||
|
||||
# fused moe ops test will hit the infer_schema error, we need add the patch
|
||||
# here to make the test pass.
|
||||
import vllm_ascend.patch.worker.patch_common.patch_utils # type: ignore[import] # isort: skip # noqa
|
||||
|
||||
|
||||
class TestBase(unittest.TestCase):
|
||||
|
||||
|
||||
167
tests/ut/patch/worker/patch_common/test_patch_linear.py
Normal file
167
tests/ut/patch/worker/patch_common/test_patch_linear.py
Normal file
@@ -0,0 +1,167 @@
|
||||
from importlib import reload
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import vllm
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from tests.ut.base import PytestBase
|
||||
from vllm_ascend import envs
|
||||
from vllm_ascend.patch.worker.patch_common import patch_linear
|
||||
|
||||
|
||||
class TestAscendRowParallelLinear(PytestBase):
|
||||
|
||||
def init_row_parallel_linear(self, mocker: MockerFixture):
|
||||
mocker.patch(
|
||||
"vllm_ascend.patch.worker.patch_common.patch_linear.AscendRowParallelLinear.__init__",
|
||||
return_value=None,
|
||||
)
|
||||
mocker.patch("torch.nn.Module.__setattr__")
|
||||
mocker.patch("torch.nn.Module.__getattr__")
|
||||
mocker.patch("torch.nn.Module.__delattr__")
|
||||
return patch_linear.AscendRowParallelLinear(
|
||||
input_size=128,
|
||||
output_size=256,
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"version, expected",
|
||||
[
|
||||
("1.0.0", 1),
|
||||
("2.1.0", 1),
|
||||
],
|
||||
)
|
||||
def test_get_hcomm_info(self, version, expected, mocker: MockerFixture):
|
||||
mock_group = mocker.MagicMock()
|
||||
backend = mocker.MagicMock()
|
||||
backend.get_hccl_comm_name = lambda x: x
|
||||
mock_group._get_backend = lambda x: backend
|
||||
mock_group.get_hccl_comm_name = lambda x: x
|
||||
mocker.patch("torch.distributed.get_rank", return_value=1)
|
||||
mocker.patch(
|
||||
"torch.distributed.get_global_rank",
|
||||
return_value=0,
|
||||
)
|
||||
mocker.patch("torch.__version__", new=version)
|
||||
hcomm_info = patch_linear.AscendRowParallelLinear.get_hcomm_info(
|
||||
mock_group)
|
||||
assert hcomm_info == expected
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"skip_bias_add, return_bias, bias, expected",
|
||||
[
|
||||
(True, False, torch.tensor(1.0), torch.tensor(14.0)),
|
||||
(False, True, torch.tensor(1.0), (torch.tensor(14.0), None)),
|
||||
(
|
||||
True,
|
||||
True,
|
||||
torch.tensor(1.0),
|
||||
(torch.tensor(14.0), torch.tensor(1.0)),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_forward(
|
||||
self,
|
||||
skip_bias_add,
|
||||
return_bias,
|
||||
bias,
|
||||
expected,
|
||||
mocker: MockerFixture,
|
||||
):
|
||||
mocker_tp_group = mocker.MagicMock()
|
||||
mocker_tp_group.device_group = mocker.MagicMock()
|
||||
row_parallel_linear = self.init_row_parallel_linear(mocker)
|
||||
row_parallel_linear.__dict__["tp_rank"] = 0
|
||||
row_parallel_linear.__dict__["skip_bias_add"] = skip_bias_add
|
||||
row_parallel_linear.__dict__["return_bias"] = return_bias
|
||||
row_parallel_linear.__dict__["bias"] = bias
|
||||
row_parallel_linear.__dict__["qyuant_method"] = mocker.MagicMock()
|
||||
row_parallel_linear.__dict__["calc_input"] = lambda x: x # noqa
|
||||
row_parallel_linear.__dict__[
|
||||
"calc_output"] = lambda x: x.matmul( # noqa
|
||||
torch.tensor([1.0, 2.0]))
|
||||
ret = row_parallel_linear.forward(torch.tensor([10.0, 2.0]))
|
||||
if isinstance(ret, tuple):
|
||||
assert torch.allclose(ret[0], expected[0])
|
||||
if ret[1] is None:
|
||||
assert ret[1] == expected[1]
|
||||
else:
|
||||
assert torch.allclose(ret[1], expected[1])
|
||||
else:
|
||||
assert torch.allclose(ret, expected)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_is_parallel, expected",
|
||||
[
|
||||
(True, torch.tensor([10.0, 2.0])),
|
||||
(False, torch.tensor([10.0])),
|
||||
],
|
||||
)
|
||||
def test_calc_input(
|
||||
self,
|
||||
input_is_parallel,
|
||||
expected,
|
||||
mocker: MockerFixture,
|
||||
):
|
||||
row_parallel_linear = self.init_row_parallel_linear(mocker)
|
||||
row_parallel_linear.__dict__["input_is_parallel"] = input_is_parallel
|
||||
input_tensor = torch.Tensor([10, 2])
|
||||
mocker.patch(
|
||||
"vllm_ascend.patch.worker.patch_common.patch_linear.get_tensor_model_parallel_rank", # noqa
|
||||
return_value=0,
|
||||
)
|
||||
mocker.patch(
|
||||
"vllm_ascend.patch.worker.patch_common.patch_linear.split_tensor_along_last_dim", # noqa
|
||||
return_value=[torch.Tensor([10]),
|
||||
torch.Tensor([2])],
|
||||
)
|
||||
input_parallel = row_parallel_linear.calc_input(input_tensor)
|
||||
assert torch.allclose(input_parallel, expected)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"reduce_results, tp_size, expected",
|
||||
[
|
||||
(True, 2, torch.tensor(56.0)),
|
||||
(True, 1, torch.tensor(14.0)),
|
||||
(False, 2, torch.tensor(14.0)),
|
||||
],
|
||||
)
|
||||
def test_calc_output(
|
||||
self,
|
||||
reduce_results,
|
||||
tp_size,
|
||||
expected,
|
||||
mocker: MockerFixture,
|
||||
):
|
||||
quant_method = mocker.MagicMock()
|
||||
quant_method.apply = lambda self, x, bias=None: x.matmul( # noqa
|
||||
torch.tensor([1.0, 2.0]))
|
||||
row_parallel_linear = self.init_row_parallel_linear(mocker)
|
||||
row_parallel_linear.__dict__["reduce_results"] = reduce_results
|
||||
row_parallel_linear.__dict__["tp_size"] = tp_size
|
||||
row_parallel_linear.__dict__["quant_method"] = quant_method
|
||||
row_parallel_linear.__dict__["tp_rank"] = 0
|
||||
row_parallel_linear.__dict__["get_hcomm_info"] = lambda x: None # noqa
|
||||
|
||||
mocker.patch(
|
||||
"vllm_ascend.patch.worker.patch_common.patch_linear.get_tp_group",
|
||||
return_value=mocker.MagicMock(device_group=mocker.MagicMock()),
|
||||
)
|
||||
mocker.patch(
|
||||
"torch_npu.npu_mm_all_reduce_base",
|
||||
side_effect=lambda input_, weight, hccl_info, bias: input_.
|
||||
matmul( # noqa
|
||||
torch.tensor([4.0, 8.0])),
|
||||
) # noqa
|
||||
ret = row_parallel_linear.calc_output(torch.tensor([10.0, 2.0]))
|
||||
assert torch.allclose(ret, expected)
|
||||
|
||||
def test_enable_allreduce_matmul(self, mocker: MockerFixture):
|
||||
mocker.patch.object(envs,
|
||||
"VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE",
|
||||
new=True)
|
||||
reload(patch_linear)
|
||||
assert envs.VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE
|
||||
assert id(vllm.model_executor.layers.linear.RowParallelLinear) == id(
|
||||
patch_linear.AscendRowParallelLinear)
|
||||
@@ -154,7 +154,11 @@ env_variables: Dict[str, Callable[[], Any]] = {
|
||||
# Whether to enable mla_pa for deepseek mla decode, this flag will be removed after its available torch_npu is public accessible
|
||||
# and the mla_pa will be the default path of deepseek decode path.
|
||||
"VLLM_ASCEND_MLA_PA":
|
||||
lambda: int(os.getenv("VLLM_ASCEND_MLA_PA", 0))
|
||||
lambda: int(os.getenv("VLLM_ASCEND_MLA_PA", 0)),
|
||||
# Whether to enable MatmulAllReduce fusion kernel when tensor parallel is enabled.
|
||||
# this feature is supported in A2, and eager mode will get better performance.
|
||||
"VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE":
|
||||
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE", '0'))),
|
||||
}
|
||||
|
||||
# end-env-vars-definition
|
||||
|
||||
@@ -114,3 +114,19 @@
|
||||
# - https://github.com/vllm-project/vllm/pull/21591
|
||||
# Future Plan:
|
||||
# Revert it when vLLM merge #21591 and release new version
|
||||
# ** File: worker/patch_common/patch_linear.py **
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
# 1. `vllm.model_executor.layers.linear.RowParallelLinear`
|
||||
# Why:
|
||||
# We need to fuse matmul and allreuce in `RowParallelLinear`
|
||||
# to improve performance.
|
||||
# How:
|
||||
# Create a new class `AscendRowParallelLinear` that inherits from `RowParallelLinear`.
|
||||
# In this class, we override the `forward` method to use
|
||||
# torch_npu.npu_mm_all_reduce_base to replace matmul and allreduce.
|
||||
# Related PR (if no, explain why):
|
||||
# - https://github.com/vllm-project/vllm-ascend/pull/1926
|
||||
# Future Plan:
|
||||
# Validate more models in all kinds of scenario,
|
||||
# if performance is always improved, we can enable this patch by default and remove the env
|
||||
# variable `VLLM_ASCEND_ENABLE_FUSE_MATMUL_ALLREDUCE` in the future.
|
||||
|
||||
@@ -19,5 +19,6 @@
|
||||
# patch files.
|
||||
import vllm_ascend.patch.worker.patch_common.patch_utils # noqa isort:skip
|
||||
import vllm_ascend.patch.worker.patch_common.patch_distributed # noqa
|
||||
import vllm_ascend.patch.worker.patch_common.patch_linear # noqa
|
||||
import vllm_ascend.patch.worker.patch_common.patch_minicpm # noqa
|
||||
import vllm_ascend.patch.worker.patch_common.patch_sampler # noqa
|
||||
|
||||
145
vllm_ascend/patch/worker/patch_common/patch_linear.py
Normal file
145
vllm_ascend/patch/worker/patch_common/patch_linear.py
Normal file
@@ -0,0 +1,145 @@
|
||||
"""
|
||||
Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
This file is a part of the vllm-ascend project.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
import vllm
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch.nn.parameter import Parameter
|
||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
split_tensor_along_last_dim)
|
||||
from vllm.distributed.parallel_state import get_tp_group
|
||||
from vllm.model_executor.layers.linear import RowParallelLinear
|
||||
|
||||
from vllm_ascend import envs
|
||||
|
||||
_HCOMM_INFO = None
|
||||
|
||||
|
||||
class AscendRowParallelLinear(RowParallelLinear):
|
||||
"""
|
||||
AscendRowParallelLinear is a custom implementation of RowParallelLinear
|
||||
that overrides the forward method to handle Ascend-specific operations.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""Initialize the AscendRowParallelLinear layer.
|
||||
|
||||
Args:
|
||||
*args: Variable length argument list.
|
||||
**kwargs: Arbitrary keyword arguments.
|
||||
"""
|
||||
tp_group = get_tp_group().device_group
|
||||
hcomm_info = self.get_hcomm_info(tp_group)
|
||||
self.hcomm_info = hcomm_info
|
||||
super().__init__(*args, **kwargs)
|
||||
self.weight_t = self.weight.t()
|
||||
|
||||
@staticmethod
|
||||
def get_hcomm_info(group: ProcessGroup) -> str:
|
||||
"""Get the HCCL communication information for the given group.
|
||||
|
||||
Args:
|
||||
group (ProcessGroup): The process group for which to get the HCCL communication info.
|
||||
|
||||
Returns:
|
||||
str: The HCCL communication name for the given group.
|
||||
"""
|
||||
global _HCOMM_INFO
|
||||
if _HCOMM_INFO is not None:
|
||||
return _HCOMM_INFO
|
||||
|
||||
rank = torch.distributed.get_rank(group)
|
||||
if torch.__version__ > "2.0":
|
||||
global_rank = torch.distributed.get_global_rank(group, rank)
|
||||
_HCOMM_INFO = group._get_backend(
|
||||
torch.device("npu")).get_hccl_comm_name(global_rank)
|
||||
|
||||
else:
|
||||
_HCOMM_INFO = group.get_hccl_comm_name(rank)
|
||||
return _HCOMM_INFO
|
||||
|
||||
def forward(
|
||||
self, input_: torch.Tensor
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
||||
"""Forward pass for the AscendRowParallelLinear layer.
|
||||
|
||||
Args:
|
||||
input_ (torch.Tensor): the input tensor to the layer.
|
||||
|
||||
Returns:
|
||||
Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
||||
The output tensor after applying the linear transformation,
|
||||
and optionally the bias if `return_bias` is True.
|
||||
"""
|
||||
input_parallel = self.calc_input(input_)
|
||||
|
||||
# Matrix multiply.
|
||||
assert self.quant_method is not None
|
||||
# Only fuse bias add into GEMM for rank 0 (this ensures that
|
||||
# bias will not get added more than once in TP>1 case)
|
||||
output = self.calc_output(input_parallel)
|
||||
|
||||
output_bias = self.bias if self.skip_bias_add else None
|
||||
|
||||
if not self.return_bias:
|
||||
return output
|
||||
return output, output_bias
|
||||
|
||||
def calc_input(self, input_: torch.Tensor) -> torch.Tensor:
|
||||
"""Calculate the input tensor for parallel processing.
|
||||
|
||||
Args:
|
||||
input_ (torch.Tensor): the input tensor to be processed.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The input tensor split along the last dimension
|
||||
for tensor model parallelism, or the original input if not parallel.
|
||||
"""
|
||||
if self.input_is_parallel:
|
||||
return input_
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
splitted_input = split_tensor_along_last_dim(
|
||||
input_, num_partitions=self.tp_size)
|
||||
return splitted_input[tp_rank].contiguous()
|
||||
|
||||
def calc_output(self, input_parallel: torch.Tensor) -> torch.Tensor:
|
||||
"""Calculate the output tensor of forward by considering
|
||||
fusing communication and computation.
|
||||
|
||||
Args:
|
||||
input_parallel (_type_): the input tensor to be processed in parallel.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: the output tensor after applying the linear transformation
|
||||
and optionally handle communication between tensor model parallel ranks.
|
||||
"""
|
||||
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
|
||||
if self.reduce_results and self.tp_size > 1:
|
||||
output = torch_npu.npu_mm_all_reduce_base(input_parallel,
|
||||
self.weight_t,
|
||||
self.hcomm_info,
|
||||
bias=bias_)
|
||||
else:
|
||||
output = self.quant_method.apply(self, input_parallel, bias=bias_)
|
||||
return output
|
||||
|
||||
|
||||
if envs.VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE:
|
||||
vllm.model_executor.layers.linear.RowParallelLinear = AscendRowParallelLinear
|
||||
Reference in New Issue
Block a user