add mlp tp optimize (#2120)
### What this PR does / why we need it?
For dense models, by not applying tensor parallelism (TP) to the
attention module and applying TP only to the MLP module, the all-reduce
operations in the attention module can be eliminated, thereby reducing
communication overhead. However, this approach increases memory usage,
so the environment variable VLLM_ASCEND_ENABLE_MLP_OPTIMIZE is used to
gate this optimization.
- vLLM main:
b17109beea
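
Usage sketch (illustrative, not part of the diff): the flag is read lazily through `vllm_ascend.envs`, so it only needs to be set before engine initialization.

```python
import os

# Opt in to the MLP TP optimization; it defaults to "0" (disabled).
os.environ["VLLM_ASCEND_ENABLE_MLP_OPTIMIZE"] = "1"

import vllm_ascend.envs as envs_ascend

# The env_variables entry added in this PR parses the value as a bool.
assert envs_ascend.VLLM_ASCEND_ENABLE_MLP_OPTIMIZE
```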
Signed-off-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
Co-authored-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
tests/ut/ops/test_linear.py (new file, 363 lines)
@@ -0,0 +1,363 @@
import os
import unittest
from unittest import mock

import torch

from vllm_ascend.ops.linear import (AscendMlpColumnParallelLinear,
                                    AscendMlpMergedColumnParallelLinear,
                                    AscendMlpRowParallelLinear, LinearBase,
                                    QuantizationConfig)


class TestAscendMlpRowParallelLinear(unittest.TestCase):

    def setUp(self):
        os.environ["VLLM_ASCEND_ENABLE_MLP_OPTIMIZE"] = "1"
        self.tensor_parallel_world_size = 2
        self.tensor_parallel_rank = 0
        self.mlp_tensor_parallel_world_size = 2
        self.mlp_tensor_parallel_rank = 1

        self.get_tensor_model_parallel_world_size_patch = mock.patch(
            'vllm_ascend.ops.linear.get_tensor_model_parallel_world_size',
            return_value=self.tensor_parallel_world_size)
        self.get_tensor_model_parallel_rank_patch = mock.patch(
            'vllm_ascend.ops.linear.get_tensor_model_parallel_rank',
            return_value=self.tensor_parallel_rank)
        self.get_mlp_tensor_model_parallel_world_size_patch = mock.patch(
            'vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_world_size',
            return_value=self.mlp_tensor_parallel_world_size)
        self.get_mlp_tensor_model_parallel_rank_patch = mock.patch(
            'vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_rank',
            return_value=self.mlp_tensor_parallel_rank)

        self.get_tensor_model_parallel_world_size_mock = \
            self.get_tensor_model_parallel_world_size_patch.start()
        self.get_tensor_model_parallel_rank_mock = \
            self.get_tensor_model_parallel_rank_patch.start()
        self.get_mlp_tensor_model_parallel_world_size_mock = \
            self.get_mlp_tensor_model_parallel_world_size_patch.start()
        self.get_mlp_tensor_model_parallel_rank_mock = \
            self.get_mlp_tensor_model_parallel_rank_patch.start()

        self.split_tensor_along_last_dim_patch = mock.patch(
            'vllm_ascend.ops.linear.split_tensor_along_last_dim',
            return_value=(torch.randn(10, 8), torch.randn(10, 8)))
        self.tensor_model_parallel_all_reduce_patch = mock.patch(
            'vllm_ascend.ops.linear.tensor_model_parallel_all_reduce',
            return_value=torch.randn(10, 8))
        self.tensor_model_parallel_all_reduce_mock = \
            self.tensor_model_parallel_all_reduce_patch.start()
        self.split_tensor_along_last_dim_mock = \
            self.split_tensor_along_last_dim_patch.start()
        self.get_mlp_tp_group_patch = \
            mock.patch('vllm_ascend.ops.linear.get_mlp_tp_group')
        self.get_mlp_tp_group_mock = self.get_mlp_tp_group_patch.start()
        self.get_mlp_tp_group_mock.return_value = mock.MagicMock()
        self.get_mlp_tp_group_mock.return_value.reduce_scatter = \
            mock.MagicMock()

    def tearDown(self):
        self.get_tensor_model_parallel_world_size_patch.stop()
        self.get_tensor_model_parallel_rank_patch.stop()
        self.get_mlp_tensor_model_parallel_world_size_patch.stop()
        self.get_mlp_tensor_model_parallel_rank_patch.stop()
        self.split_tensor_along_last_dim_patch.stop()
        self.tensor_model_parallel_all_reduce_patch.stop()
        self.get_mlp_tp_group_patch.stop()

    def test_init_with_down_proj_prefix(self):
        layer = AscendMlpRowParallelLinear(input_size=16,
                                           output_size=8,
                                           prefix="down_proj")
        self.assertEqual(layer.tp_size, self.mlp_tensor_parallel_world_size)
        self.assertEqual(layer.tp_rank, self.mlp_tensor_parallel_rank)
        self.assertTrue(layer.enable_mlp_optimze)

    def test_forward_with_mlp_optimize(self):
        layer = AscendMlpRowParallelLinear(
            input_size=16,
            output_size=8,
            prefix="down_proj",
            input_is_parallel=False,
        )
        input_tensor = torch.randn(16, 8)  # (batch_size, input_size)
        layer(input_tensor)

        self.split_tensor_along_last_dim_mock.assert_called_once_with(
            input_tensor, num_partitions=layer.tp_size)

    def test_forward_without_mlp_optimize(self):
        layer = AscendMlpRowParallelLinear(
            input_size=16,
            output_size=8,
            prefix="other",
            input_is_parallel=False,
        )
        input_tensor = torch.randn(16, 8)
        layer(input_tensor)

        self.split_tensor_along_last_dim_mock.assert_called_once_with(
            input_tensor, num_partitions=layer.tp_size)
        self.tensor_model_parallel_all_reduce_mock.assert_called_once()

    def test_skip_bias_add(self):
        layer = AscendMlpRowParallelLinear(
            input_size=16,
            output_size=8,
            skip_bias_add=True,
        )
        input_tensor = torch.randn(16, 8)
        output, bias = layer(input_tensor)

        self.assertIsNotNone(bias)

    def test_no_reduce_results(self):
        layer = AscendMlpRowParallelLinear(input_size=16,
                                           output_size=8,
                                           reduce_results=False,
                                           bias=False)
        input_tensor = torch.randn(16, 8)
        layer(input_tensor)

        self.tensor_model_parallel_all_reduce_mock.assert_not_called()

    def test_input_not_parallel(self):
        layer = AscendMlpRowParallelLinear(input_size=16,
                                           output_size=8,
                                           input_is_parallel=False)
        input_tensor = torch.randn(16, 8)
        layer(input_tensor)

        self.split_tensor_along_last_dim_mock.assert_called_once()

    def test_exception_when_reduce_false_and_bias(self):
        with self.assertRaises(ValueError):
            AscendMlpRowParallelLinear(input_size=16,
                                       output_size=8,
                                       reduce_results=False,
                                       bias=True,
                                       skip_bias_add=False)


class TestAscendMlpColumnParallelLinear(unittest.TestCase):

    def setUp(self):
        os.environ["VLLM_ASCEND_ENABLE_MLP_OPTIMIZE"] = "1"
        # Mock distributed functions
        self.mlp_tp_size_patch = \
            mock.patch('vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_world_size')
        self.mlp_tp_size_mock = self.mlp_tp_size_patch.start()
        self.mlp_tp_size_mock.return_value = 2  # Simulate 2 GPUs in MLP TP group

        self.mlp_tp_rank_patch = \
            mock.patch('vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_rank')
        self.mlp_tp_rank_mock = self.mlp_tp_rank_patch.start()
        self.mlp_tp_rank_mock.return_value = 0  # Current GPU rank

        self.tp_size_patch = \
            mock.patch('vllm_ascend.ops.linear.get_tensor_model_parallel_world_size')
        self.tp_size_mock = self.tp_size_patch.start()
        self.tp_size_mock.return_value = 4  # Simulate 4 GPUs in regular TP group

        self.tp_rank_patch = \
            mock.patch('vllm_ascend.ops.linear.get_tensor_model_parallel_rank')
        self.tp_rank_mock = self.tp_rank_patch.start()
        self.tp_rank_mock.return_value = 1  # Current GPU rank

        # Mock the divide helper imported in vllm_ascend.ops.linear
        self.divide_patch = mock.patch('vllm_ascend.ops.linear.divide')
        self.divide_mock = self.divide_patch.start()
        self.divide_mock.side_effect = lambda x, y: x // y  # Simulate division

        # Mock QuantizationConfig and QuantMethod
        self.quant_config_mock = mock.MagicMock(spec=QuantizationConfig)

        # Mock LinearBase initialization
        self.linear_base_init_patch = mock.patch.object(
            LinearBase, "__init__", side_effect=self.mock_linear_base_init)
        self.linear_base_init_patch.start()

        self.quant_method_mock = mock.MagicMock()

    def mock_linear_base_init(self, instance, *args, **kwargs):
        instance.quant_method = self.quant_method_mock

        instance.input_size = 16
        instance.output_size = 8
        instance.output_size_per_partition = 4
        instance.params_dtype = torch.float32

    def tearDown(self):
        self.mlp_tp_size_patch.stop()
        self.mlp_tp_rank_patch.stop()
        self.tp_size_patch.stop()
        self.tp_rank_patch.stop()
        self.divide_patch.stop()
        self.linear_base_init_patch.stop()

    def test_mlp_optimize_initialization(self):
        # Test when prefix contains "gate_up_proj"
        with mock.patch.object(torch.nn.Module, 'register_parameter'):
            layer = AscendMlpColumnParallelLinear(
                input_size=16,
                output_size=8,
                prefix="model.layers.0.gate_up_proj",
                bias=False,
            )

        # Verify MLP optimization flags
        self.assertTrue(layer.enable_mlp_optimze)
        self.assertEqual(layer.tp_size, 2)
        self.assertEqual(layer.tp_rank, 0)
        self.assertEqual(layer.input_size_per_partition, 16)
        self.assertEqual(layer.output_size_per_partition, 4)

        # Check quant_method.create_weights was called
        self.quant_method_mock.create_weights.assert_called_once()

    def test_regular_parallel_initialization(self):
        # Test when prefix does NOT contain "gate_up_proj"
        with mock.patch.object(torch.nn.Module, 'register_parameter'):
            layer = AscendMlpColumnParallelLinear(
                input_size=16,
                output_size=8,
                prefix="model.layers.0.q_proj",
                quant_config=self.quant_config_mock,
                bias=False,
            )

        # Verify regular TP flags
        self.assertFalse(layer.enable_mlp_optimze)
        self.assertEqual(layer.tp_size, 4)
        self.assertEqual(layer.tp_rank, 1)
        self.assertEqual(layer.input_size_per_partition, 16)
        self.assertEqual(layer.output_size_per_partition, 4)
        # Check quant_method.create_weights was called
        self.quant_method_mock.create_weights.assert_called_once()

    def test_output_sizes_handling(self):
        # Test when output_sizes is provided
        with mock.patch.object(torch.nn.Module, 'register_parameter'):
            layer = AscendMlpColumnParallelLinear(
                input_size=16,
                output_size=8,
                output_sizes=[4, 4],
                prefix="model.layers.0.qkv_proj",
                quant_config=self.quant_config_mock,
                bias=False,
            )

        # Verify output_partition_sizes
        self.assertEqual(layer.output_partition_sizes, [2])


class TestAscendMlpMergedColumnParallelLinear(unittest.TestCase):

    def setUp(self):
        os.environ["VLLM_ASCEND_ENABLE_MLP_OPTIMIZE"] = "1"
        # Mock get_mlp_tensor_model_parallel_world_size and get_tensor_model_parallel_world_size
        self.mlp_world_size_patch = \
            mock.patch("vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_world_size", return_value=2)
        self.tensor_world_size_patch = \
            mock.patch("vllm_ascend.ops.linear.get_tensor_model_parallel_world_size", return_value=2)
        self.mlp_world_size_patch.start()
        self.tensor_world_size_patch.start()

        # Mock get_mlp_tensor_model_parallel_rank and get_tensor_model_parallel_rank
        self.mlp_rank_patch = \
            mock.patch("vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_rank", return_value=0)
        self.tensor_rank_patch = \
            mock.patch("vllm_ascend.ops.linear.get_tensor_model_parallel_rank", return_value=0)
        self.mlp_rank_patch.start()
        self.tensor_rank_patch.start()

        # Mock all_gather methods
        self.get_mlp_tp_group_patch = \
            mock.patch('vllm_ascend.ops.linear.get_mlp_tp_group')
        self.get_mlp_tp_group_mock = self.get_mlp_tp_group_patch.start()
        self.get_mlp_tp_group_mock.return_value = mock.MagicMock()
        self.get_mlp_tp_group_mock.return_value.all_gather = mock.MagicMock()
        self.tensor_model_parallel_all_gather_patch = mock.patch(
            'vllm_ascend.ops.linear.tensor_model_parallel_all_gather',
            return_value=torch.randn(10, 8))
        self.tensor_model_parallel_all_gather_mock = \
            self.tensor_model_parallel_all_gather_patch.start()

        # Mock AscendMlpColumnParallelLinear's __init__
        self.linear_init_patch = mock.patch.object(
            AscendMlpColumnParallelLinear,
            "__init__",
            side_effect=self.mock_linear_init)
        self.linear_init_patch.start()

        # Create mock objects
        self.quant_method_mock = mock.MagicMock()
        self.apply_output = torch.randn(2, 8)

        self.quant_method_mock.apply.return_value = self.apply_output

    def mock_linear_init(self, instance, *args, **kwargs):
        torch.nn.Module.__init__(instance)
        # Set quant_method and other attributes
        instance.quant_method = self.quant_method_mock
        instance.bias = torch.nn.Parameter(torch.randn(8))  # Example bias
        instance.input_size = 16
        instance.output_size = 8
        instance.gather_output = False
        instance.skip_bias_add = False
        instance.return_bias = True

    def test_forward_with_enable_mlp_optimze(self):
        # Setup input
        input_tensor = torch.randn(1, 16)

        # Create instance with prefix "gate_up_proj" to trigger
        # enable_mlp_optimze = True
        layer = AscendMlpMergedColumnParallelLinear(input_size=16,
                                                    output_sizes=[8],
                                                    bias=True,
                                                    gather_output=False,
                                                    skip_bias_add=False,
                                                    params_dtype=torch.float32,
                                                    quant_config=None,
                                                    prefix="gate_up_proj")

        # Call forward
        output, bias = layer(input_tensor)

        # Validate calls
        self.assertEqual(output.shape, self.apply_output.shape)

    def test_forward_without_enable_mlp_optimze(self):
        # Setup input
        input_tensor = torch.randn(1, 16)

        # Create instance with prefix not containing "gate_up_proj"
        layer = AscendMlpMergedColumnParallelLinear(input_size=16,
                                                    output_sizes=[8],
                                                    bias=True,
                                                    gather_output=False,
                                                    skip_bias_add=False,
                                                    params_dtype=torch.float32,
                                                    quant_config=None,
                                                    prefix="other_proj")

        # Call forward
        output, bias = layer(input_tensor)

        # Validate calls
        self.quant_method_mock.apply.assert_called_once_with(
            layer, input_tensor, layer.bias)
        self.tensor_model_parallel_all_gather_mock.assert_not_called()
        self.assertEqual(output.shape, self.apply_output.shape)

    def tearDown(self):
        self.linear_init_patch.stop()
        self.mlp_world_size_patch.stop()
        self.tensor_world_size_patch.stop()
        self.mlp_rank_patch.stop()
        self.tensor_rank_patch.stop()
        self.get_mlp_tp_group_patch.stop()
        self.tensor_model_parallel_all_gather_patch.stop()
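
For a quick local check, the new tests are plain unittest cases; one way to run just this file (assuming the repository's test layout is discoverable, otherwise use the project's usual pytest entry point):

```python
# Hypothetical direct runner for the new test module.
import unittest

suite = unittest.defaultTestLoader.discover("tests/ut/ops",
                                            pattern="test_linear.py")
unittest.TextTestRunner(verbosity=2).run(suite)
```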
@@ -356,13 +356,13 @@ class TestUtils(TestBase):
        # ascend custom op is not registered
        utils.register_ascend_customop()
        # should call register_oot six times
-       self.assertEqual(mock_customop.register_oot.call_count, 3)
+       self.assertEqual(mock_customop.register_oot.call_count, 6)
        self.assertTrue(utils._ASCEND_CUSTOMOP_IS_REIGISTERED)

        # ascend custom op is already registered
        utils.register_ascend_customop()
        # should not register_oot again, so the count stays at six in this ut
-       self.assertEqual(mock_customop.register_oot.call_count, 3)
+       self.assertEqual(mock_customop.register_oot.call_count, 6)


class TestProfileExecuteDuration(TestBase):
@@ -5,8 +5,11 @@ from vllm.config import ParallelConfig
from vllm.distributed.parallel_state import (GroupCoordinator, get_world_group,
                                             init_model_parallel_group)

import vllm_ascend.envs as envs_ascend

# Currently, mc2 ops need their own group coordinator.
_MC2: Optional[GroupCoordinator] = None
_MLP_TP: Optional[GroupCoordinator] = None


def get_mc2_group() -> GroupCoordinator:
@@ -14,6 +17,11 @@ def get_mc2_group() -> GroupCoordinator:
    return _MC2


def get_mlp_tp_group() -> GroupCoordinator:
    assert _MLP_TP is not None, ("mlp group is not initialized")
    return _MLP_TP


def model_parallel_initialized():
    return (_MC2 is not None)

@@ -39,6 +47,33 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
                                         get_world_group().local_rank,
                                         backend,
                                         group_name="mc2")
    if envs_ascend.VLLM_ASCEND_ENABLE_MLP_OPTIMIZE:
        global _MLP_TP
        assert _MLP_TP is None, (
            "mlp tensor model parallel group is already initialized")

        mlp_tp = parallel_config.data_parallel_size

        all_ranks_mlp_head = torch.arange(world_size).reshape(
            -1, mlp_tp, parallel_config.pipeline_parallel_size, 1)  # noqa
        group_ranks = all_ranks_mlp_head.view(-1, mlp_tp).unbind(0)
        group_ranks = [x.tolist() for x in group_ranks]

        # message queue broadcaster is only used in tensor model parallel group
        _MLP_TP = init_model_parallel_group(group_ranks,
                                            get_world_group().local_rank,
                                            backend,
                                            group_name="mlp_tp")
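
To make the grouping above concrete, here is a small worked example of the rank arithmetic, which slices the world into contiguous groups of size mlp_tp (sizes below are illustrative, not from the PR):

```python
import torch

world_size, mlp_tp, pp_size = 8, 2, 1  # illustrative: 8 ranks, DP=2, PP=1
all_ranks_mlp_head = torch.arange(world_size).reshape(-1, mlp_tp, pp_size, 1)
group_ranks = [x.tolist() for x in all_ranks_mlp_head.view(-1, mlp_tp).unbind(0)]
print(group_ranks)  # [[0, 1], [2, 3], [4, 5], [6, 7]] -> four MLP TP groups of 2
```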
def get_mlp_tensor_model_parallel_world_size():
    """Return world size for the MLP tensor model parallel group."""
    return get_mlp_tp_group().world_size


def get_mlp_tensor_model_parallel_rank():
    """Return the caller's rank in the MLP tensor model parallel group."""
    return get_mlp_tp_group().rank_in_group


def destroy_ascend_model_parallel():
@@ -46,3 +81,8 @@ def destroy_ascend_model_parallel():
    if _MC2:
        _MC2.destroy()
    _MC2 = None

    global _MLP_TP
    if _MLP_TP:
        _MLP_TP.destroy()
    _MLP_TP = None

@@ -141,6 +141,10 @@ env_variables: Dict[str, Callable[[], Any]] = {
    # 1: enable moe all2all seq.
    "VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ":
    lambda: bool(int(os.getenv('VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ', '0'))),
    # Whether to enable the MLP optimization when tensor parallelism is enabled.
    # This feature gives better performance in eager mode.
    "VLLM_ASCEND_ENABLE_MLP_OPTIMIZE":
    lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_MLP_OPTIMIZE", '0'))),
}

# end-env-vars-definition

vllm_ascend/ops/linear.py (new file, 309 lines)
@@ -0,0 +1,309 @@
"""
Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
This file is a part of the vllm-ascend project.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from typing import Optional, Union

import torch
from torch.nn.parameter import Parameter
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              split_tensor_along_last_dim,
                              tensor_model_parallel_all_gather,
                              tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.linear import (WEIGHT_LOADER_V2_SUPPORTED,
                                               ColumnParallelLinear,
                                               LinearBase,
                                               MergedColumnParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.quantization.base_config import \
    QuantizationConfig
from vllm.model_executor.utils import set_weight_attrs

from vllm_ascend.distributed.parallel_state import (
    get_mlp_tensor_model_parallel_rank,
    get_mlp_tensor_model_parallel_world_size, get_mlp_tp_group)


class AscendMlpColumnParallelLinear(ColumnParallelLinear):
    """Linear layer with column parallelism.

    Use the MLP tensor parallelism group in the MLP module,
    and the original TP group in other modules.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        output_sizes: Optional[list[int]] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
    ):
        # Divide the weight matrix along the last dimension.
        if prefix.find("gate_up_proj") != -1:
            self.tp_size = get_mlp_tensor_model_parallel_world_size()
            self.tp_rank = get_mlp_tensor_model_parallel_rank()
            self.enable_mlp_optimze = True
        else:
            self.tp_size = get_tensor_model_parallel_world_size()
            self.tp_rank = get_tensor_model_parallel_rank()
            self.enable_mlp_optimze = False
        self.input_size_per_partition = input_size
        self.output_size_per_partition = divide(output_size, self.tp_size)
        self.output_partition_sizes = [self.output_size_per_partition]
        # If QKV or MergedColumn, use output size of each partition.
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = [
                divide(output_size, self.tp_size)
                for output_size in self.output_sizes
            ]
        LinearBase.__init__(self,
                            input_size,
                            output_size,
                            skip_bias_add,
                            params_dtype,
                            quant_config,
                            prefix,
                            return_bias=return_bias)

        self.gather_output = gather_output

        if output_sizes is None:
            output_sizes = [output_size]

        assert self.quant_method is not None
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(
                self.weight_loader_v2 if self.quant_method.__class__.__name__
                in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size_per_partition,
                            dtype=params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)


class AscendMlpRowParallelLinear(RowParallelLinear):
    """Linear layer with row parallelism.
    Use the MLP tensor parallelism group in the MLP module,
    and the original TP group in other modules.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        input_is_parallel: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        reduce_results: bool = True,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
    ):
        if prefix.find("down_proj") != -1:
            self.tp_size = get_mlp_tensor_model_parallel_world_size()
            self.tp_rank = get_mlp_tensor_model_parallel_rank()
            self.enable_mlp_optimze = True
        else:
            self.tp_size = get_tensor_model_parallel_world_size()
            self.tp_rank = get_tensor_model_parallel_rank()
            self.enable_mlp_optimze = False
        # Divide the weight matrix along the first dimension.
        self.input_size_per_partition = divide(input_size, self.tp_size)
        self.output_size_per_partition = output_size
        self.output_partition_sizes = [output_size]

        LinearBase.__init__(self,
                            input_size,
                            output_size,
                            skip_bias_add,
                            params_dtype,
                            quant_config,
                            prefix,
                            return_bias=return_bias)

        self.input_is_parallel = input_is_parallel
        self.reduce_results = reduce_results

        assert self.quant_method is not None
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(
                self.weight_loader_v2 if self.quant_method.__class__.__name__
                in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
        if not reduce_results and (bias and not skip_bias_add):
            raise ValueError("When not reducing the results, adding bias to "
                             "the results can lead to incorrect results")

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def forward(
        self,
        input_,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        if self.enable_mlp_optimze:
            if self.input_is_parallel:
                input_parallel = input_
            else:
                tp_rank = get_mlp_tensor_model_parallel_rank()
                splitted_input = split_tensor_along_last_dim(
                    input_, num_partitions=self.tp_size)
                input_parallel = splitted_input[tp_rank].contiguous()
            # Matrix multiply.
            assert self.quant_method is not None
            # Only fuse bias add into GEMM for rank 0 (this ensures that
            # bias will not get added more than once in TP>1 case)
            bias_ = None if (self.tp_rank > 0
                             or self.skip_bias_add) else self.bias
            output_parallel = self.quant_method.apply(self,
                                                      input_parallel,
                                                      bias=bias_)
            output = get_mlp_tp_group().reduce_scatter(output_parallel, 0)
        else:
            if self.input_is_parallel:
                input_parallel = input_
            else:
                tp_rank = get_tensor_model_parallel_rank()
                splitted_input = split_tensor_along_last_dim(
                    input_, num_partitions=self.tp_size)
                input_parallel = splitted_input[tp_rank].contiguous()

            # Matrix multiply.
            assert self.quant_method is not None
            # Only fuse bias add into GEMM for rank 0 (this ensures that
            # bias will not get added more than once in TP>1 case)
            bias_ = None if (self.tp_rank > 0
                             or self.skip_bias_add) else self.bias
            output_parallel = self.quant_method.apply(self,
                                                      input_parallel,
                                                      bias=bias_)
            if self.reduce_results and self.tp_size > 1:
                output = tensor_model_parallel_all_reduce(output_parallel)
            else:
                output = output_parallel
        output_bias = self.bias if self.skip_bias_add else None

        if not self.return_bias:
            return output
        return output, output_bias


class AscendMlpMergedColumnParallelLinear(MergedColumnParallelLinear):
    """Packed linear layers with column parallelism.

    Similar to ColumnParallelLinear, but the weight matrix is concatenated
    along the output dimension. When the weight matrix is loaded, the
    different partitions are sharded separately.

    Use the MLP tensor parallelism group in the MLP module,
    and the original TP group in other modules.
    """

    def __init__(
        self,
        input_size: int,
        output_sizes: list[int],
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
    ):
        self.output_sizes = output_sizes
        if prefix.find("gate_up_proj") != -1:
            self.tp_size = get_mlp_tensor_model_parallel_world_size()
            self.tp_rank = get_mlp_tensor_model_parallel_rank()
            self.enable_mlp_optimze = True
        else:
            self.tp_size = get_tensor_model_parallel_world_size()
            self.tp_rank = get_tensor_model_parallel_rank()
            self.enable_mlp_optimze = False
        assert all(output_size % self.tp_size == 0
                   for output_size in output_sizes)
        AscendMlpColumnParallelLinear.__init__(self,
                                               input_size=input_size,
                                               output_size=sum(output_sizes),
                                               bias=bias,
                                               gather_output=gather_output,
                                               skip_bias_add=skip_bias_add,
                                               params_dtype=params_dtype,
                                               quant_config=quant_config,
                                               prefix=prefix,
                                               return_bias=return_bias)

    def forward(
        self,
        input_,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        bias = self.bias if not self.skip_bias_add else None
        # Matrix multiply.
        assert self.quant_method is not None
        if self.enable_mlp_optimze:
            input2_ = get_mlp_tp_group().all_gather(input_, 0)
            output = self.quant_method.apply(self, input2_, bias)
        else:
            output_parallel = self.quant_method.apply(self, input_, bias)
            if self.gather_output:
                # All-gather across the partitions.
                output = tensor_model_parallel_all_gather(output_parallel)
            else:
                output = output_parallel

        output_bias = self.bias if self.skip_bias_add else None
        if not self.return_bias:
            return output
        return output, output_bias
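
The two forwards above are designed to pair: AscendMlpMergedColumnParallelLinear all-gathers the token shards on entry to the MLP, and AscendMlpRowParallelLinear reduce-scatters on exit, which together move the same data as the single all-reduce used in plain TP. A self-contained torch.distributed sketch of that composition (illustrative only, not the vllm-ascend collectives API):

```python
import torch
import torch.distributed as dist

def mlp_tp_roundtrip(x_shard: torch.Tensor) -> torch.Tensor:
    """Each rank holds a shard of the tokens along dim 0."""
    world = dist.get_world_size()
    # gate_up_proj side: gather every rank's token shard before the GEMM.
    gathered = [torch.empty_like(x_shard) for _ in range(world)]
    dist.all_gather(gathered, x_shard)
    full = torch.cat(gathered, dim=0)
    partial = full  # stand-in for the two GEMMs + activation on this rank
    # down_proj side: sum partial results across ranks, keep one token shard.
    out_shard = torch.empty_like(x_shard)
    dist.reduce_scatter_tensor(out_shard, partial)
    return out_shard
```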
@@ -475,9 +475,20 @@ def register_ascend_customop():
    from vllm.model_executor.custom_op import CustomOp

    from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul
    from vllm_ascend.ops.linear import (AscendMlpColumnParallelLinear,
                                        AscendMlpMergedColumnParallelLinear,
                                        AscendMlpRowParallelLinear)
    CustomOp.register_oot(_decorated_op_cls=AscendQuickGELU, name="QuickGELU")
    CustomOp.register_oot(_decorated_op_cls=AscendSiluAndMul,
                          name="SiluAndMul")
    if envs_ascend.VLLM_ASCEND_ENABLE_MLP_OPTIMIZE:
        CustomOp.register_oot(_decorated_op_cls=AscendMlpColumnParallelLinear,
                              name="ColumnParallelLinear")
        CustomOp.register_oot(_decorated_op_cls=AscendMlpRowParallelLinear,
                              name="RowParallelLinear")
        CustomOp.register_oot(
            _decorated_op_cls=AscendMlpMergedColumnParallelLinear,
            name="MergedColumnParallelLinear")

    from vllm_ascend.ops.layernorm import AscendRMSNorm
    CustomOp.register_oot(_decorated_op_cls=AscendRMSNorm, name="RMSNorm")
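
For context, a minimal sketch of how the conditional registration is exercised (assuming the function lives in vllm_ascend.utils, as the unit tests above suggest):

```python
import os

os.environ["VLLM_ASCEND_ENABLE_MLP_OPTIMIZE"] = "1"

from vllm_ascend.utils import register_ascend_customop

# With the flag set, this registers the three Ascend MLP linear layers in
# addition to QuickGELU, SiluAndMul, and RMSNorm (six register_oot calls in
# total, matching the updated test assertion), so vLLM constructs them in
# place of its stock ColumnParallelLinear / RowParallelLinear /
# MergedColumnParallelLinear.
register_ascend_customop()
```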