[Feature]: implement the fusion of allreduce and matmul in prefill phase when tp is enabled (#1926)
### What this PR does / why we need it?
The forward function of vLLM's RowParallelLinear executes matmul and
allreduce separately. This PR uses torch_npu.npu_mm_all_reduce_base to
execute allreduce and matmul in a single fused kernel, which gives a 20%
performance
improvement in eager mode.
### Does this PR introduce _any_ user-facing change?
This PR introduces a new environment variable, `VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE`, to
control whether the feature is enabled.
### How was this patch tested?
The patch is tested by a new unit-test file, `test_patch_linear.py`, which
guards the new behavior.
- vLLM version: v0.10.0
- vLLM main:
7728dd77bb
Signed-off-by: Ronald1995 <ronaldautomobile@163.com>
This commit is contained in:
@@ -19,10 +19,6 @@ import pytest
|
|||||||
|
|
||||||
from vllm_ascend.utils import adapt_patch, register_ascend_customop
|
from vllm_ascend.utils import adapt_patch, register_ascend_customop
|
||||||
|
|
||||||
# fused moe ops test will hit the infer_schema error, we need add the patch
|
|
||||||
# here to make the test pass.
|
|
||||||
import vllm_ascend.patch.worker.patch_common.patch_utils # type: ignore[import] # isort: skip # noqa
|
|
||||||
|
|
||||||
|
|
||||||
class TestBase(unittest.TestCase):
|
class TestBase(unittest.TestCase):
|
||||||
|
|
||||||
|
|||||||
167
tests/ut/patch/worker/patch_common/test_patch_linear.py
Normal file
167
tests/ut/patch/worker/patch_common/test_patch_linear.py
Normal file
@@ -0,0 +1,167 @@
|
|||||||
|
from importlib import reload
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
import vllm
|
||||||
|
from pytest_mock import MockerFixture
|
||||||
|
|
||||||
|
from tests.ut.base import PytestBase
|
||||||
|
from vllm_ascend import envs
|
||||||
|
from vllm_ascend.patch.worker.patch_common import patch_linear
|
||||||
|
|
||||||
|
|
||||||
|
class TestAscendRowParallelLinear(PytestBase):
    """Unit tests for ``patch_linear.AscendRowParallelLinear``.

    The layer is never really constructed: ``__init__`` and the
    ``torch.nn.Module`` attribute hooks are mocked, and each test seeds the
    attributes it needs directly through the instance ``__dict__``.
    """

    def init_row_parallel_linear(self, mocker: MockerFixture):
        """Return a layer instance whose constructor does nothing.

        ``torch.nn.Module.__setattr__``/``__getattr__``/``__delattr__`` are
        replaced by mocks, so any attribute that a test did not put into
        ``__dict__`` resolves to a mock instead of raising.
        """
        mocker.patch(
            "vllm_ascend.patch.worker.patch_common.patch_linear.AscendRowParallelLinear.__init__",
            return_value=None,
        )
        mocker.patch("torch.nn.Module.__setattr__")
        mocker.patch("torch.nn.Module.__getattr__")
        mocker.patch("torch.nn.Module.__delattr__")
        return patch_linear.AscendRowParallelLinear(
            input_size=128,
            output_size=256,
        )

    @pytest.mark.parametrize(
        "version, expected",
        [
            # Legacy branch: the name is looked up on the group with the
            # group-local rank (mocked to 1).
            ("1.0.0", 1),
            # torch > 2.0 branch: the name is looked up on the backend with
            # the global rank (mocked to 0).
            ("2.1.0", 0),
        ],
    )
    def test_get_hcomm_info(self, version, expected, mocker: MockerFixture):
        """Each torch-version branch returns the comm name for its own rank."""
        mock_group = mocker.MagicMock()
        backend = mocker.MagicMock()
        backend.get_hccl_comm_name = lambda x: x
        mock_group._get_backend = lambda x: backend
        mock_group.get_hccl_comm_name = lambda x: x
        mocker.patch("torch.distributed.get_rank", return_value=1)
        mocker.patch(
            "torch.distributed.get_global_rank",
            return_value=0,
        )
        mocker.patch("torch.__version__", new=version)
        # The backend lookup ignores its argument here, so stub the device
        # constructor to keep the test independent of whether the "npu"
        # device type is registered in the running process.
        mocker.patch("torch.device")
        # get_hcomm_info caches its result in a module-level global. Reset it
        # so every parametrized case exercises its own branch instead of
        # returning whatever an earlier case cached; mocker restores the
        # original value on teardown.
        mocker.patch.object(patch_linear, "_HCOMM_INFO", new=None)
        hcomm_info = patch_linear.AscendRowParallelLinear.get_hcomm_info(
            mock_group)
        assert hcomm_info == expected

    @pytest.mark.parametrize(
        "skip_bias_add, return_bias, bias, expected",
        [
            (True, False, torch.tensor(1.0), torch.tensor(14.0)),
            (False, True, torch.tensor(1.0), (torch.tensor(14.0), None)),
            (
                True,
                True,
                torch.tensor(1.0),
                (torch.tensor(14.0), torch.tensor(1.0)),
            ),
        ],
    )
    def test_forward(
        self,
        skip_bias_add,
        return_bias,
        bias,
        expected,
        mocker: MockerFixture,
    ):
        """forward() returns the output, plus the bias only when requested."""
        row_parallel_linear = self.init_row_parallel_linear(mocker)
        row_parallel_linear.__dict__["tp_rank"] = 0
        row_parallel_linear.__dict__["skip_bias_add"] = skip_bias_add
        row_parallel_linear.__dict__["return_bias"] = return_bias
        row_parallel_linear.__dict__["bias"] = bias
        # forward() asserts that quant_method is set, so seed it explicitly.
        row_parallel_linear.__dict__["quant_method"] = mocker.MagicMock()
        row_parallel_linear.__dict__["calc_input"] = lambda x: x  # noqa
        row_parallel_linear.__dict__[
            "calc_output"] = lambda x: x.matmul(  # noqa
                torch.tensor([1.0, 2.0]))
        ret = row_parallel_linear.forward(torch.tensor([10.0, 2.0]))
        if isinstance(ret, tuple):
            assert torch.allclose(ret[0], expected[0])
            if ret[1] is None:
                assert expected[1] is None
            else:
                assert torch.allclose(ret[1], expected[1])
        else:
            assert torch.allclose(ret, expected)

    @pytest.mark.parametrize(
        "input_is_parallel, expected",
        [
            (True, torch.tensor([10.0, 2.0])),
            (False, torch.tensor([10.0])),
        ],
    )
    def test_calc_input(
        self,
        input_is_parallel,
        expected,
        mocker: MockerFixture,
    ):
        """calc_input() passes parallel input through, else takes rank 0's shard."""
        row_parallel_linear = self.init_row_parallel_linear(mocker)
        row_parallel_linear.__dict__["input_is_parallel"] = input_is_parallel
        input_tensor = torch.Tensor([10, 2])
        mocker.patch(
            "vllm_ascend.patch.worker.patch_common.patch_linear.get_tensor_model_parallel_rank",  # noqa
            return_value=0,
        )
        mocker.patch(
            "vllm_ascend.patch.worker.patch_common.patch_linear.split_tensor_along_last_dim",  # noqa
            return_value=[torch.Tensor([10]),
                          torch.Tensor([2])],
        )
        input_parallel = row_parallel_linear.calc_input(input_tensor)
        assert torch.allclose(input_parallel, expected)

    @pytest.mark.parametrize(
        "reduce_results, tp_size, expected",
        [
            # Fused path (mocked kernel multiplies by [4, 8]).
            (True, 2, torch.tensor(56.0)),
            # Non-fused path (mocked quant_method multiplies by [1, 2]).
            (True, 1, torch.tensor(14.0)),
            (False, 2, torch.tensor(14.0)),
        ],
    )
    def test_calc_output(
        self,
        reduce_results,
        tp_size,
        expected,
        mocker: MockerFixture,
    ):
        """calc_output() uses the fused kernel only when reducing with tp_size > 1."""
        quant_method = mocker.MagicMock()
        quant_method.apply = lambda self, x, bias=None: x.matmul(  # noqa
            torch.tensor([1.0, 2.0]))
        row_parallel_linear = self.init_row_parallel_linear(mocker)
        row_parallel_linear.__dict__["reduce_results"] = reduce_results
        row_parallel_linear.__dict__["tp_size"] = tp_size
        row_parallel_linear.__dict__["quant_method"] = quant_method
        row_parallel_linear.__dict__["tp_rank"] = 0
        row_parallel_linear.__dict__["get_hcomm_info"] = lambda x: None  # noqa

        mocker.patch(
            "vllm_ascend.patch.worker.patch_common.patch_linear.get_tp_group",
            return_value=mocker.MagicMock(device_group=mocker.MagicMock()),
        )
        mocker.patch(
            "torch_npu.npu_mm_all_reduce_base",
            side_effect=lambda input_, weight, hccl_info, bias: input_.
            matmul(  # noqa
                torch.tensor([4.0, 8.0])),
        )  # noqa
        ret = row_parallel_linear.calc_output(torch.tensor([10.0, 2.0]))
        assert torch.allclose(ret, expected)

    def test_enable_allreduce_matmul(self, mocker: MockerFixture):
        """Reloading the patch module with the env flag on swaps the vLLM class."""
        mocker.patch.object(envs,
                            "VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE",
                            new=True)
        reload(patch_linear)
        assert envs.VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE
        assert (vllm.model_executor.layers.linear.RowParallelLinear
                is patch_linear.AscendRowParallelLinear)
|
||||||
@@ -154,7 +154,11 @@ env_variables: Dict[str, Callable[[], Any]] = {
|
|||||||
# Whether to enable mla_pa for deepseek mla decode, this flag will be removed after its available torch_npu is public accessible
|
# Whether to enable mla_pa for deepseek mla decode, this flag will be removed after its available torch_npu is public accessible
|
||||||
# and the mla_pa will be the default path of deepseek decode path.
|
# and the mla_pa will be the default path of deepseek decode path.
|
||||||
"VLLM_ASCEND_MLA_PA":
|
"VLLM_ASCEND_MLA_PA":
|
||||||
lambda: int(os.getenv("VLLM_ASCEND_MLA_PA", 0))
|
lambda: int(os.getenv("VLLM_ASCEND_MLA_PA", 0)),
|
||||||
|
# Whether to enable MatmulAllReduce fusion kernel when tensor parallel is enabled.
|
||||||
|
# this feature is supported in A2, and eager mode will get better performance.
|
||||||
|
"VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE":
|
||||||
|
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE", '0'))),
|
||||||
}
|
}
|
||||||
|
|
||||||
# end-env-vars-definition
|
# end-env-vars-definition
|
||||||
|
|||||||
@@ -114,3 +114,19 @@
|
|||||||
# - https://github.com/vllm-project/vllm/pull/21591
|
# - https://github.com/vllm-project/vllm/pull/21591
|
||||||
# Future Plan:
|
# Future Plan:
|
||||||
# Revert it when vLLM merge #21591 and release new version
|
# Revert it when vLLM merge #21591 and release new version
|
||||||
|
# ** File: worker/patch_common/patch_linear.py **
|
||||||
|
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
# 1. `vllm.model_executor.layers.linear.RowParallelLinear`
|
||||||
|
# Why:
|
||||||
|
# We need to fuse matmul and allreduce in `RowParallelLinear`
|
||||||
|
# to improve performance.
|
||||||
|
# How:
|
||||||
|
# Create a new class `AscendRowParallelLinear` that inherits from `RowParallelLinear`.
|
||||||
|
# In this class, we override the `forward` method to use
|
||||||
|
# torch_npu.npu_mm_all_reduce_base to replace matmul and allreduce.
|
||||||
|
# Related PR (if no, explain why):
|
||||||
|
# - https://github.com/vllm-project/vllm-ascend/pull/1926
|
||||||
|
# Future Plan:
|
||||||
|
# Validate more models in all kinds of scenario,
|
||||||
|
# if performance is always improved, we can enable this patch by default and remove the env
|
||||||
|
# variable `VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE` in the future.
|
||||||
|
|||||||
@@ -19,5 +19,6 @@
|
|||||||
# patch files.
|
# patch files.
|
||||||
import vllm_ascend.patch.worker.patch_common.patch_utils # noqa isort:skip
|
import vllm_ascend.patch.worker.patch_common.patch_utils # noqa isort:skip
|
||||||
import vllm_ascend.patch.worker.patch_common.patch_distributed # noqa
|
import vllm_ascend.patch.worker.patch_common.patch_distributed # noqa
|
||||||
|
import vllm_ascend.patch.worker.patch_common.patch_linear # noqa
|
||||||
import vllm_ascend.patch.worker.patch_common.patch_minicpm # noqa
|
import vllm_ascend.patch.worker.patch_common.patch_minicpm # noqa
|
||||||
import vllm_ascend.patch.worker.patch_common.patch_sampler # noqa
|
import vllm_ascend.patch.worker.patch_common.patch_sampler # noqa
|
||||||
|
|||||||
145
vllm_ascend/patch/worker/patch_common/patch_linear.py
Normal file
145
vllm_ascend/patch/worker/patch_common/patch_linear.py
Normal file
@@ -0,0 +1,145 @@
|
|||||||
|
"""
|
||||||
|
Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||||
|
This file is a part of the vllm-ascend project.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch_npu
|
||||||
|
import vllm
|
||||||
|
from torch.distributed import ProcessGroup
|
||||||
|
from torch.nn.parameter import Parameter
|
||||||
|
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||||
|
split_tensor_along_last_dim)
|
||||||
|
from vllm.distributed.parallel_state import get_tp_group
|
||||||
|
from vllm.model_executor.layers.linear import RowParallelLinear
|
||||||
|
|
||||||
|
from vllm_ascend import envs
|
||||||
|
|
||||||
|
_HCOMM_INFO = None
|
||||||
|
|
||||||
|
|
||||||
|
class AscendRowParallelLinear(RowParallelLinear):
    """RowParallelLinear variant that can fuse matmul and all-reduce.

    When results are reduced across more than one tensor-parallel rank,
    the output is computed with ``torch_npu.npu_mm_all_reduce_base``;
    in every other case the layer's ``quant_method`` is applied as usual.
    """

    def __init__(self, *args, **kwargs):
        """Set up the layer and the extra state used by the fused kernel.

        Args:
            *args: Positional arguments forwarded to ``RowParallelLinear``.
            **kwargs: Keyword arguments forwarded to ``RowParallelLinear``.
        """
        # The HCCL communicator name is resolved before the parent
        # constructor runs and is stored on the instance.
        self.hcomm_info = self.get_hcomm_info(get_tp_group().device_group)
        super().__init__(*args, **kwargs)
        # Transposed weight, passed to the fused kernel in calc_output.
        self.weight_t = self.weight.t()

    @staticmethod
    def get_hcomm_info(group: ProcessGroup) -> str:
        """Return the HCCL communicator name for ``group``.

        The first successful lookup is stored in the module-level
        ``_HCOMM_INFO``; later calls return that cached value without
        consulting ``group`` again.

        Args:
            group (ProcessGroup): Process group to query.

        Returns:
            str: The HCCL communicator name.
        """
        global _HCOMM_INFO
        if _HCOMM_INFO is None:
            group_rank = torch.distributed.get_rank(group)
            if torch.__version__ > "2.0":
                # Newer torch: ask the NPU backend of the group, keyed by
                # the global rank.
                world_rank = torch.distributed.get_global_rank(
                    group, group_rank)
                npu_backend = group._get_backend(torch.device("npu"))
                _HCOMM_INFO = npu_backend.get_hccl_comm_name(world_rank)
            else:
                # Older torch: the group itself exposes the name, keyed by
                # the rank inside the group.
                _HCOMM_INFO = group.get_hccl_comm_name(group_rank)
        return _HCOMM_INFO

    def forward(
        self, input_: torch.Tensor
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Run the layer on ``input_``.

        Args:
            input_ (torch.Tensor): Input tensor of the layer.

        Returns:
            Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
                The output tensor alone when ``return_bias`` is false;
                otherwise ``(output, bias)``, where ``bias`` is ``self.bias``
                if ``skip_bias_add`` is set and ``None`` if it is not.
        """
        parallel_input = self.calc_input(input_)

        assert self.quant_method is not None
        # calc_output decides whether the bias is added inside the matmul.
        result = self.calc_output(parallel_input)

        if self.skip_bias_add:
            deferred_bias = self.bias
        else:
            deferred_bias = None

        if self.return_bias:
            return result, deferred_bias
        return result

    def calc_input(self, input_: torch.Tensor) -> torch.Tensor:
        """Select the part of ``input_`` this rank has to process.

        Args:
            input_ (torch.Tensor): Input tensor of the layer.

        Returns:
            torch.Tensor: ``input_`` unchanged when ``input_is_parallel`` is
                set; otherwise the contiguous slice for the current tensor
                parallel rank after splitting ``input_`` along its last
                dimension into ``tp_size`` partitions.
        """
        if not self.input_is_parallel:
            rank = get_tensor_model_parallel_rank()
            shards = split_tensor_along_last_dim(input_,
                                                 num_partitions=self.tp_size)
            return shards[rank].contiguous()
        return input_

    def calc_output(self, input_parallel: torch.Tensor) -> torch.Tensor:
        """Compute the layer output, fusing communication when possible.

        Args:
            input_parallel (torch.Tensor): Input tensor for this rank.

        Returns:
            torch.Tensor: Result of ``torch_npu.npu_mm_all_reduce_base`` when
                ``reduce_results`` is set and ``tp_size`` is greater than 1,
                otherwise the result of ``quant_method.apply``.
        """
        # Only fuse the bias add into the GEMM on rank 0, so the bias is not
        # added more than once when TP > 1; skip it entirely if the caller
        # asked for the bias to be returned instead.
        if self.tp_rank > 0 or self.skip_bias_add:
            bias_ = None
        else:
            bias_ = self.bias

        if not (self.reduce_results and self.tp_size > 1):
            return self.quant_method.apply(self, input_parallel, bias=bias_)

        return torch_npu.npu_mm_all_reduce_base(input_parallel,
                                                self.weight_t,
                                                self.hcomm_info,
                                                bias=bias_)
|
||||||
|
|
||||||
|
|
||||||
|
# Opt-in patch: when the env flag is set, replace the RowParallelLinear
# attribute of vLLM's linear module with the Ascend implementation.
if envs.VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE:
    setattr(vllm.model_executor.layers.linear, "RowParallelLinear",
            AscendRowParallelLinear)
|
||||||
Reference in New Issue
Block a user