diff --git a/tests/ut/ops/test_linear.py b/tests/ut/ops/test_linear.py
new file mode 100644
index 0000000..28b26b7
--- /dev/null
+++ b/tests/ut/ops/test_linear.py
@@ -0,0 +1,363 @@
+import os
+import unittest
+from unittest import mock
+
+import torch
+
+from vllm_ascend.ops.linear import (AscendMlpColumnParallelLinear,
+                                    AscendMlpMergedColumnParallelLinear,
+                                    AscendMlpRowParallelLinear, LinearBase,
+                                    QuantizationConfig)
+
+
+class TestAscendMlpRowParallelLinear(unittest.TestCase):
+
+    def setUp(self):
+        os.environ["VLLM_ASCEND_ENABLE_MLP_OPTIMIZE"] = "1"
+        self.tensor_parallel_world_size = 2
+        self.tensor_parallel_rank = 0
+        self.mlp_tensor_parallel_world_size = 2
+        self.mlp_tensor_parallel_rank = 1
+
+        self.get_tensor_model_parallel_world_size_patch = mock.patch(
+            'vllm_ascend.ops.linear.get_tensor_model_parallel_world_size',
+            return_value=self.tensor_parallel_world_size)
+        self.get_tensor_model_parallel_rank_patch = mock.patch(
+            'vllm_ascend.ops.linear.get_tensor_model_parallel_rank',
+            return_value=self.tensor_parallel_rank)
+        self.get_mlp_tensor_model_parallel_world_size_patch = mock.patch(
+            'vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_world_size',
+            return_value=self.mlp_tensor_parallel_world_size)
+        self.get_mlp_tensor_model_parallel_rank_patch = mock.patch(
+            'vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_rank',
+            return_value=self.mlp_tensor_parallel_rank)
+
+        self.get_tensor_model_parallel_world_size_mock = \
+            self.get_tensor_model_parallel_world_size_patch.start()
+        self.get_tensor_model_parallel_rank_mock = \
+            self.get_tensor_model_parallel_rank_patch.start()
+        self.get_mlp_tensor_model_parallel_world_size_mock = \
+            self.get_mlp_tensor_model_parallel_world_size_patch.start()
+        self.get_mlp_tensor_model_parallel_rank_mock = \
+            self.get_mlp_tensor_model_parallel_rank_patch.start()
+
+        self.split_tensor_along_last_dim_patch = mock.patch(
+            'vllm_ascend.ops.linear.split_tensor_along_last_dim',
+            return_value=(torch.randn(10, 8),
+                          torch.randn(10, 8)))
+        self.tensor_model_parallel_all_reduce_patch = mock.patch(
+            'vllm_ascend.ops.linear.tensor_model_parallel_all_reduce',
+            return_value=torch.randn(10, 8))
+        self.tensor_model_parallel_all_reduce_mock = \
+            self.tensor_model_parallel_all_reduce_patch.start()
+        self.split_tensor_along_last_dim_mock = \
+            self.split_tensor_along_last_dim_patch.start()
+        self.get_mlp_tp_group_patch = \
+            mock.patch('vllm_ascend.ops.linear.get_mlp_tp_group')
+        self.get_mlp_tp_group_mock = self.get_mlp_tp_group_patch.start()
+        self.get_mlp_tp_group_mock.return_value = mock.MagicMock()
+        self.get_mlp_tp_group_mock.return_value.reduce_scatter = \
+            mock.MagicMock()
+
+    def tearDown(self):
+        self.get_tensor_model_parallel_world_size_patch.stop()
+        self.get_tensor_model_parallel_rank_patch.stop()
+        self.get_mlp_tensor_model_parallel_world_size_patch.stop()
+        self.get_mlp_tensor_model_parallel_rank_patch.stop()
+        self.split_tensor_along_last_dim_patch.stop()
+        self.tensor_model_parallel_all_reduce_patch.stop()
+        self.get_mlp_tp_group_patch.stop()
+
+    def test_init_with_down_proj_prefix(self):
+        layer = AscendMlpRowParallelLinear(input_size=16,
+                                           output_size=8,
+                                           prefix="down_proj")
+        self.assertEqual(layer.tp_size, self.mlp_tensor_parallel_world_size)
+        self.assertEqual(layer.tp_rank, self.mlp_tensor_parallel_rank)
+        self.assertTrue(layer.enable_mlp_optimze)
+
+    def test_forward_with_mlp_optimize(self):
+        layer = AscendMlpRowParallelLinear(
+            input_size=16,
+            output_size=8,
+            prefix="down_proj",
+            input_is_parallel=False,
+        )
+        input_tensor = torch.randn(16, 8)  # (batch_size, input_size)
+        layer(input_tensor)
+
+        self.split_tensor_along_last_dim_mock.assert_called_once_with(
+            input_tensor, num_partitions=layer.tp_size)
+
+    def test_forward_without_mlp_optimize(self):
+        layer = AscendMlpRowParallelLinear(
+            input_size=16,
+            output_size=8,
+            prefix="other",
+            input_is_parallel=False,
+        )
+        input_tensor = torch.randn(16, 8)
+        layer(input_tensor)
+
+        self.split_tensor_along_last_dim_mock.assert_called_once_with(
+            input_tensor, num_partitions=layer.tp_size)
+        self.tensor_model_parallel_all_reduce_mock.assert_called_once()
+
+    def test_skip_bias_add(self):
+        layer = AscendMlpRowParallelLinear(
+            input_size=16,
+            output_size=8,
+            skip_bias_add=True,
+        )
+        input_tensor = torch.randn(16, 8)
+        output, bias = layer(input_tensor)
+
+        self.assertIsNotNone(bias)
+
+    def test_no_reduce_results(self):
+        layer = AscendMlpRowParallelLinear(input_size=16,
+                                           output_size=8,
+                                           reduce_results=False,
+                                           bias=False)
+        input_tensor = torch.randn(16, 8)
+        layer(input_tensor)
+
+        self.tensor_model_parallel_all_reduce_mock.assert_not_called()
+
+    def test_input_not_parallel(self):
+        layer = AscendMlpRowParallelLinear(input_size=16,
+                                           output_size=8,
+                                           input_is_parallel=False)
+        input_tensor = torch.randn(16, 8)
+        layer(input_tensor)
+
+        self.split_tensor_along_last_dim_mock.assert_called_once()
+
+    def test_exception_when_reduce_false_and_bias(self):
+        with self.assertRaises(ValueError):
+            AscendMlpRowParallelLinear(input_size=16,
+                                       output_size=8,
+                                       reduce_results=False,
+                                       bias=True,
+                                       skip_bias_add=False)
+
+
+class TestAscendMlpColumnParallelLinear(unittest.TestCase):
+
+    def setUp(self):
+        os.environ["VLLM_ASCEND_ENABLE_MLP_OPTIMIZE"] = "1"
+        # Mock distributed functions
+        self.mlp_tp_size_patch = \
+            mock.patch('vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_world_size')
+        self.mlp_tp_size_mock = self.mlp_tp_size_patch.start()
+        self.mlp_tp_size_mock.return_value = 2  # Simulate 2 GPUs in MLP TP group
+
+        self.mlp_tp_rank_patch = \
+            mock.patch('vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_rank')
+        self.mlp_tp_rank_mock = self.mlp_tp_rank_patch.start()
+        self.mlp_tp_rank_mock.return_value = 0  # Current GPU rank
+
+        self.tp_size_patch = \
+            mock.patch('vllm_ascend.ops.linear.get_tensor_model_parallel_world_size')
+        self.tp_size_mock = self.tp_size_patch.start()
+        self.tp_size_mock.return_value = 4  # Simulate 4 GPUs in regular TP group
+
+        self.tp_rank_patch = \
+            mock.patch('vllm_ascend.ops.linear.get_tensor_model_parallel_rank')
+        self.tp_rank_mock = self.tp_rank_patch.start()
+        self.tp_rank_mock.return_value = 1  # Current GPU rank
+
+        # Mock divide function (assumed to be in your module)
+        self.divide_patch = mock.patch('vllm_ascend.ops.linear.divide')
+        self.divide_mock = self.divide_patch.start()
+        self.divide_mock.side_effect = lambda x, y: x // y  # Simulate division
+
+        # Mock QuantizationConfig and QuantMethod
+        self.quant_config_mock = mock.MagicMock(spec=QuantizationConfig)
+
+        # Mock LinearBase initialization
+        self.linear_base_init_patch = mock.patch.object(
+            LinearBase, "__init__", side_effect=self.mock_linear_base_init)
+        self.linear_base_init_patch.start()
+
+        self.quant_method_mock = mock.MagicMock()
+
+    def mock_linear_base_init(self, instance, *args, **kwargs):
+        instance.quant_method = self.quant_method_mock
+        instance.params_dtype = mock.MagicMock()
+
+        instance.input_size = 16
+        instance.output_size = 8
+        instance.output_size_per_partition = 4
+        instance.params_dtype = torch.float32
+
+    def tearDown(self):
+        self.mlp_tp_size_patch.stop()
+        self.mlp_tp_rank_patch.stop()
+        self.tp_size_patch.stop()
+        self.tp_rank_patch.stop()
+        self.divide_patch.stop()
+        self.linear_base_init_patch.stop()
+
+    def test_mlp_optimize_initialization(self):
+        # Test when prefix contains "gate_up_proj"
+        with mock.patch.object(torch.nn.Module, 'register_parameter'):
+            layer = AscendMlpColumnParallelLinear(
+                input_size=16,
+                output_size=8,
+                prefix="model.layers.0.gate_up_proj",
+                bias=False,
+            )
+
+        # Verify MLP optimization flags
+        self.assertTrue(layer.enable_mlp_optimze)
+        self.assertEqual(layer.tp_size, 2)
+        self.assertEqual(layer.tp_rank, 0)
+        self.assertEqual(layer.input_size_per_partition, 16)
+        self.assertEqual(layer.output_size_per_partition, 4)
+
+        # Check quant_method.create_weights was called
+        self.quant_method_mock.create_weights.assert_called_once()
+
+    def test_regular_parallel_initialization(self):
+        # Test when prefix does NOT contain "gate_up_proj"
+        with mock.patch.object(torch.nn.Module, 'register_parameter'):
+            layer = AscendMlpColumnParallelLinear(
+                input_size=16,
+                output_size=8,
+                prefix="model.layers.0.q_proj",
+                quant_config=self.quant_config_mock,
+                bias=False,
+            )
+
+        # Verify regular TP flags
+        self.assertFalse(layer.enable_mlp_optimze)
+        self.assertEqual(layer.tp_size, 4)
+        self.assertEqual(layer.tp_rank, 1)
+        self.assertEqual(layer.input_size_per_partition, 16)
+        self.assertEqual(layer.output_size_per_partition, 4)
+        # Check quant_method.create_weights was called
+        self.quant_method_mock.create_weights.assert_called_once()
+
+    def test_output_sizes_handling(self):
+        # Test when output_sizes is provided
+        with mock.patch.object(torch.nn.Module, 'register_parameter'):
+            layer = AscendMlpColumnParallelLinear(
+                input_size=16,
+                output_size=8,
+                output_sizes=[4, 4],
+                prefix="model.layers.0.qkv_proj",
+                quant_config=self.quant_config_mock,
+                bias=False,
+            )
+
+        # Verify output_partition_sizes
+        self.assertEqual(layer.output_partition_sizes, [2])
+
+
+class TestAscendMlpMergedColumnParallelLinear(unittest.TestCase):
+
+    def setUp(self):
+        os.environ["VLLM_ASCEND_ENABLE_MLP_OPTIMIZE"] = "1"
+        # Mock get_mlp_tensor_model_parallel_world_size and get_tensor_model_parallel_world_size
+        self.mlp_world_size_patch = \
+            mock.patch("vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_world_size", return_value=2)
+        self.tensor_world_size_patch = \
+            mock.patch("vllm_ascend.ops.linear.get_tensor_model_parallel_world_size", return_value=2)
+        self.mlp_world_size_patch.start()
+        self.tensor_world_size_patch.start()
+
+        # Mock get_mlp_tensor_model_parallel_rank and get_tensor_model_parallel_rank
+        self.mlp_rank_patch = \
+            mock.patch("vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_rank", return_value=0)
+        self.tensor_rank_patch = \
+            mock.patch("vllm_ascend.ops.linear.get_tensor_model_parallel_rank",
+                       return_value=0)
+        self.mlp_rank_patch.start()
+        self.tensor_rank_patch.start()
+
+        # Mock all_gather methods
+        self.get_mlp_tp_group_patch = \
+            mock.patch('vllm_ascend.ops.linear.get_mlp_tp_group')
+        self.get_mlp_tp_group_mock = self.get_mlp_tp_group_patch.start()
+        self.get_mlp_tp_group_mock.return_value = mock.MagicMock()
+        self.get_mlp_tp_group_mock.return_value.all_gather = mock.MagicMock()
+        self.tensor_model_parallel_all_gather_patch = mock.patch(
+            'vllm_ascend.ops.linear.tensor_model_parallel_all_gather',
+            return_value=torch.randn(10, 8))
+        self.tensor_model_parallel_all_gather_mock = \
+            self.tensor_model_parallel_all_gather_patch.start()
+
+        # Mock AscendMlpColumnParallelLinear's __init__
+        self.linear_init_patch = mock.patch.object(
+            AscendMlpColumnParallelLinear,
+            "__init__",
+            side_effect=self.mock_linear_init)
+        self.linear_init_patch.start()
+
+        # Create mock objects
+        self.quant_method_mock = mock.MagicMock()
+        self.apply_output = torch.randn(2, 8)
+
+        self.quant_method_mock.apply.return_value = self.apply_output
+
+    def mock_linear_init(self, instance, *args, **kwargs):
+        torch.nn.Module.__init__(instance)
+        # Set quant_method and other attributes
+        instance.quant_method = self.quant_method_mock
+        instance.bias = torch.nn.Parameter(torch.randn(8))  # Example bias
+        instance.input_size = 16
+        instance.output_size = 8
+        instance.gather_output = False
+        instance.skip_bias_add = False
+        instance.return_bias = True
+
+    def test_forward_with_enable_mlp_optimze(self):
+        # Setup input
+        input_tensor = torch.randn(1, 16)
+
+        # Create instance with prefix "gate_up_proj" to trigger enable_mlp_optimze = True
+        layer = AscendMlpMergedColumnParallelLinear(input_size=16,
+                                                    output_sizes=[8],
+                                                    bias=True,
+                                                    gather_output=False,
+                                                    skip_bias_add=False,
+                                                    params_dtype=torch.float32,
+                                                    quant_config=None,
+                                                    prefix="gate_up_proj")
+
+        # Call forward
+        output, bias = layer(input_tensor)
+
+        # Validate calls: the MLP-TP all-gather path must be taken
+        self.get_mlp_tp_group_mock.return_value.all_gather.assert_called_once()
+        self.assertEqual(output.shape, self.apply_output.shape)
+
+    def test_forward_without_enable_mlp_optimze(self):
+        # Setup input
+        input_tensor = torch.randn(1, 16)
+
+        # Create instance with prefix not containing "gate_up_proj"
+        layer = AscendMlpMergedColumnParallelLinear(input_size=16,
+                                                    output_sizes=[8],
+                                                    bias=True,
+                                                    gather_output=False,
+                                                    skip_bias_add=False,
+                                                    params_dtype=torch.float32,
+                                                    quant_config=None,
+                                                    prefix="other_proj")
+
+        # Call forward
+        output, bias = layer(input_tensor)
+
+        # Validate calls
+        self.quant_method_mock.apply.assert_called_once_with(
+            layer, input_tensor, layer.bias)
+        self.tensor_model_parallel_all_gather_mock.assert_not_called()
+        self.assertEqual(output.shape, self.apply_output.shape)
+
+    def tearDown(self):
+        self.linear_init_patch.stop()
+        self.mlp_world_size_patch.stop()
+        self.tensor_world_size_patch.stop()
+        self.mlp_rank_patch.stop()
+        self.tensor_rank_patch.stop()
+        self.get_mlp_tp_group_patch.stop()
+        self.tensor_model_parallel_all_gather_patch.stop()
diff --git a/tests/ut/test_utils.py b/tests/ut/test_utils.py
index d3c4c90..db4fcc3 100644
--- a/tests/ut/test_utils.py
+++ b/tests/ut/test_utils.py
@@ -356,13 +356,13 @@ class TestUtils(TestBase):
         # ascend custom op is not registered
         utils.register_ascend_customop()
         # should call register_oot three
-        self.assertEqual(mock_customop.register_oot.call_count, 3)
+        self.assertEqual(mock_customop.register_oot.call_count, 6)
         self.assertTrue(utils._ASCEND_CUSTOMOP_IS_REIGISTERED)
 
         # ascend custom op is already registered
         utils.register_ascend_customop()
         # should not register_oot again, thus only called three in this ut
-        self.assertEqual(mock_customop.register_oot.call_count, 3)
+        self.assertEqual(mock_customop.register_oot.call_count, 6)
 
 
 class TestProfileExecuteDuration(TestBase):
diff --git a/vllm_ascend/distributed/parallel_state.py b/vllm_ascend/distributed/parallel_state.py
index 1e31359..db1c5a8 100644
--- a/vllm_ascend/distributed/parallel_state.py
+++ b/vllm_ascend/distributed/parallel_state.py
@@ -5,8 +5,11 @@
 from vllm.config import ParallelConfig
 from vllm.distributed.parallel_state import (GroupCoordinator, get_world_group,
                                              init_model_parallel_group)
 
+import vllm_ascend.envs as envs_ascend
+
 # Currently, mc2 op need their own group coordinator.
 _MC2: Optional[GroupCoordinator] = None
+_MLP_TP: Optional[GroupCoordinator] = None
 
 
 def get_mc2_group() -> GroupCoordinator:
@@ -14,6 +17,11 @@ def get_mc2_group() -> GroupCoordinator:
     return _MC2
 
 
+def get_mlp_tp_group() -> GroupCoordinator:
+    assert _MLP_TP is not None, ("mlp group is not initialized")
+    return _MLP_TP
+
+
 def model_parallel_initialized():
     return (_MC2 is not None)
 
@@ -39,6 +47,33 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
                                      get_world_group().local_rank,
                                      backend,
                                      group_name="mc2")
+    if envs_ascend.VLLM_ASCEND_ENABLE_MLP_OPTIMIZE:
+        global _MLP_TP
+        assert _MLP_TP is None, (
+            "mlp tensor model parallel group is already initialized")
+
+        mlp_tp = parallel_config.data_parallel_size
+
+        all_ranks_mlp_head = torch.arange(world_size).reshape(
+            -1, mlp_tp, parallel_config.pipeline_parallel_size, 1)  # noqa
+        group_ranks = all_ranks_mlp_head.view(-1, mlp_tp).unbind(0)
+        group_ranks = [x.tolist() for x in group_ranks]
+
+        # message queue broadcaster is only used in tensor model parallel group
+        _MLP_TP = init_model_parallel_group(group_ranks,
+                                            get_world_group().local_rank,
+                                            backend,
+                                            group_name="mlp_tp")
+
+
+def get_mlp_tensor_model_parallel_world_size():
+    """Return world size for the mlp tensor model parallel group."""
+    return get_mlp_tp_group().world_size
+
+
+def get_mlp_tensor_model_parallel_rank():
+    """Return the rank of this process in the mlp tensor parallel group."""
+    return get_mlp_tp_group().rank_in_group
 
 
 def destroy_ascend_model_parallel():
@@ -46,3 +81,8 @@ def destroy_ascend_model_parallel():
     if _MC2:
         _MC2.destroy()
     _MC2 = None
+
+    global _MLP_TP
+    if _MLP_TP:
+        _MLP_TP.destroy()
+    _MLP_TP = None
diff --git a/vllm_ascend/envs.py b/vllm_ascend/envs.py
index 16148bb..8469297 100644
--- a/vllm_ascend/envs.py
+++ b/vllm_ascend/envs.py
@@ -141,6 +141,10 @@ env_variables: Dict[str, Callable[[], Any]] = {
     # 1: enable moe all2all seq.
     "VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ":
     lambda: bool(int(os.getenv('VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ', '0'))),
+    # Whether to enable the mlp optimization when tensor parallel is enabled.
+    # This feature improves performance in eager mode.
+    "VLLM_ASCEND_ENABLE_MLP_OPTIMIZE":
+    lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_MLP_OPTIMIZE", '0'))),
 }
 
 # end-env-vars-definition
diff --git a/vllm_ascend/ops/linear.py b/vllm_ascend/ops/linear.py
new file mode 100644
index 0000000..e2f427e
--- /dev/null
+++ b/vllm_ascend/ops/linear.py
@@ -0,0 +1,309 @@
+"""
+Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+This file is a part of the vllm-ascend project.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+from typing import Optional, Union
+
+import torch
+from torch.nn.parameter import Parameter
+from vllm.distributed import (divide, get_tensor_model_parallel_rank,
+                              get_tensor_model_parallel_world_size,
+                              split_tensor_along_last_dim,
+                              tensor_model_parallel_all_gather,
+                              tensor_model_parallel_all_reduce)
+from vllm.model_executor.layers.linear import (WEIGHT_LOADER_V2_SUPPORTED,
+                                               ColumnParallelLinear,
+                                               LinearBase,
+                                               MergedColumnParallelLinear,
+                                               RowParallelLinear)
+from vllm.model_executor.layers.quantization.base_config import \
+    QuantizationConfig
+from vllm.model_executor.utils import set_weight_attrs
+
+from vllm_ascend.distributed.parallel_state import (
+    get_mlp_tensor_model_parallel_rank,
+    get_mlp_tensor_model_parallel_world_size, get_mlp_tp_group)
+
+
+class AscendMlpColumnParallelLinear(ColumnParallelLinear):
+    """Linear layer with column parallelism.
+
+    Use the MLP tensor parallelism group in the MLP module,
+    and the original TP group in other modules.
+    """
+
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int,
+        bias: bool = True,
+        gather_output: bool = False,
+        skip_bias_add: bool = False,
+        params_dtype: Optional[torch.dtype] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        output_sizes: Optional[list[int]] = None,
+        prefix: str = "",
+        *,
+        return_bias: bool = True,
+    ):
+        # Divide the weight matrix along the last dimension.
+        if prefix.find("gate_up_proj") != -1:
+            self.tp_size = get_mlp_tensor_model_parallel_world_size()
+            self.tp_rank = get_mlp_tensor_model_parallel_rank()
+            self.enable_mlp_optimze = True
+        else:
+            self.tp_size = get_tensor_model_parallel_world_size()
+            self.tp_rank = get_tensor_model_parallel_rank()
+            self.enable_mlp_optimze = False
+        self.input_size_per_partition = input_size
+        self.output_size_per_partition = divide(output_size, self.tp_size)
+        self.output_partition_sizes = [self.output_size_per_partition]
+        # If QKV or MergedColumn, use output size of each partition.
+        if hasattr(self, "output_sizes"):
+            self.output_partition_sizes = [
+                divide(output_size, self.tp_size)
+                for output_size in self.output_sizes
+            ]
+        LinearBase.__init__(self,
+                            input_size,
+                            output_size,
+                            skip_bias_add,
+                            params_dtype,
+                            quant_config,
+                            prefix,
+                            return_bias=return_bias)
+
+        self.gather_output = gather_output
+
+        if output_sizes is None:
+            output_sizes = [output_size]
+
+        assert self.quant_method is not None
+        self.quant_method.create_weights(
+            layer=self,
+            input_size_per_partition=self.input_size_per_partition,
+            output_partition_sizes=self.output_partition_sizes,
+            input_size=self.input_size,
+            output_size=self.output_size,
+            params_dtype=self.params_dtype,
+            weight_loader=(
+                self.weight_loader_v2 if self.quant_method.__class__.__name__
+                in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
+        if bias:
+            self.bias = Parameter(
+                torch.empty(self.output_size_per_partition,
+                            dtype=params_dtype))
+            set_weight_attrs(self.bias, {
+                "output_dim": 0,
+                "weight_loader": self.weight_loader,
+            })
+        else:
+            self.register_parameter("bias", None)
+
+
+class AscendMlpRowParallelLinear(RowParallelLinear):
+    """Linear layer with row parallelism.
+    Use the MLP tensor parallelism group in the MLP module,
+    and the original TP group in other modules.
+    """
+
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int,
+        bias: bool = True,
+        input_is_parallel: bool = True,
+        skip_bias_add: bool = False,
+        params_dtype: Optional[torch.dtype] = None,
+        reduce_results: bool = True,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        *,
+        return_bias: bool = True,
+    ):
+        if prefix.find("down_proj") != -1:
+            self.tp_size = get_mlp_tensor_model_parallel_world_size()
+            self.tp_rank = get_mlp_tensor_model_parallel_rank()
+            self.enable_mlp_optimze = True
+        else:
+            self.tp_size = get_tensor_model_parallel_world_size()
+            self.tp_rank = get_tensor_model_parallel_rank()
+            self.enable_mlp_optimze = False
+        # Divide the weight matrix along the first dimension.
+        self.input_size_per_partition = divide(input_size, self.tp_size)
+        self.output_size_per_partition = output_size
+        self.output_partition_sizes = [output_size]
+
+        LinearBase.__init__(self,
+                            input_size,
+                            output_size,
+                            skip_bias_add,
+                            params_dtype,
+                            quant_config,
+                            prefix,
+                            return_bias=return_bias)
+
+        self.input_is_parallel = input_is_parallel
+        self.reduce_results = reduce_results
+
+        assert self.quant_method is not None
+        self.quant_method.create_weights(
+            layer=self,
+            input_size_per_partition=self.input_size_per_partition,
+            output_partition_sizes=self.output_partition_sizes,
+            input_size=self.input_size,
+            output_size=self.output_size,
+            params_dtype=self.params_dtype,
+            weight_loader=(
+                self.weight_loader_v2 if self.quant_method.__class__.__name__
+                in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
+        if not reduce_results and (bias and not skip_bias_add):
+            raise ValueError("When not reduce the results, adding bias to the "
+                             "results can lead to incorrect results")
+
+        if bias:
+            self.bias = Parameter(
+                torch.empty(self.output_size, dtype=params_dtype))
+            set_weight_attrs(self.bias, {
+                "output_dim": 0,
+                "weight_loader": self.weight_loader,
+            })
+        else:
+            self.register_parameter("bias", None)
+
+    def forward(
+        self,
+        input_,
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
+        if self.enable_mlp_optimze:
+            tp_rank = get_mlp_tensor_model_parallel_rank()
+            if self.input_is_parallel:
+                input_parallel = input_
+            else:
+                tp_rank = get_mlp_tensor_model_parallel_rank()
+                splitted_input = split_tensor_along_last_dim(
+                    input_, num_partitions=self.tp_size)
+                input_parallel = splitted_input[tp_rank].contiguous()
+            # Matrix multiply.
+            assert self.quant_method is not None
+            # Only fuse bias add into GEMM for rank 0 (this ensures that
+            # bias will not get added more than once in TP>1 case)
+            bias_ = None if (self.tp_rank > 0
+                             or self.skip_bias_add) else self.bias
+            output_parallel = self.quant_method.apply(self,
+                                                      input_parallel,
+                                                      bias=bias_)
+            output = get_mlp_tp_group().reduce_scatter(output_parallel, 0)
+            # output = output[:num_tokens,:]
+            # dispose_tensor(output_parallel)
+        else:
+            if self.input_is_parallel:
+                input_parallel = input_
+            else:
+                tp_rank = get_tensor_model_parallel_rank()
+                splitted_input = split_tensor_along_last_dim(
+                    input_, num_partitions=self.tp_size)
+                input_parallel = splitted_input[tp_rank].contiguous()
+
+            # Matrix multiply.
+            assert self.quant_method is not None
+            # Only fuse bias add into GEMM for rank 0 (this ensures that
+            # bias will not get added more than once in TP>1 case)
+            bias_ = None if (self.tp_rank > 0
+                             or self.skip_bias_add) else self.bias
+            output_parallel = self.quant_method.apply(self,
+                                                      input_parallel,
+                                                      bias=bias_)
+            if self.reduce_results and self.tp_size > 1:
+                output = tensor_model_parallel_all_reduce(output_parallel)
+            else:
+                output = output_parallel
+        output_bias = self.bias if self.skip_bias_add else None
+
+        if not self.return_bias:
+            return output
+        return output, output_bias
+
+
+class AscendMlpMergedColumnParallelLinear(MergedColumnParallelLinear):
+    """Packed linear layers with column parallelism.
+
+    Similar to ColumnParallelLinear, but the weight matrix is concatenated
+    along the output dimension. When the weight matrix is loaded, the
+    different partitions are sharded separately.
+
+    Use the MLP tensor parallelism group in the MLP module,
+    and the original TP group in other modules.
+    """
+
+    def __init__(
+        self,
+        input_size: int,
+        output_sizes: list[int],
+        bias: bool = True,
+        gather_output: bool = False,
+        skip_bias_add: bool = False,
+        params_dtype: Optional[torch.dtype] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        *,
+        return_bias: bool = True,
+    ):
+        self.output_sizes = output_sizes
+        if prefix.find("gate_up_proj") != -1:
+            self.tp_size = get_mlp_tensor_model_parallel_world_size()
+            self.tp_rank = get_mlp_tensor_model_parallel_rank()
+            self.enable_mlp_optimze = True
+        else:
+            self.tp_size = get_tensor_model_parallel_world_size()
+            self.tp_rank = get_tensor_model_parallel_rank()
+            self.enable_mlp_optimze = False
+        assert all(output_size % self.tp_size == 0
+                   for output_size in output_sizes)
+        AscendMlpColumnParallelLinear.__init__(self,
+                                               input_size=input_size,
+                                               output_size=sum(output_sizes),
+                                               bias=bias,
+                                               gather_output=gather_output,
+                                               skip_bias_add=skip_bias_add,
+                                               params_dtype=params_dtype,
+                                               quant_config=quant_config,
+                                               prefix=prefix,
+                                               return_bias=return_bias)
+
+    def forward(
+        self,
+        input_,
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
+        bias = self.bias if not self.skip_bias_add else None
+        # self.global_batch_size = vllm_config.scheduler_config.max_num_seqs
+        # Matrix multiply.
+        assert self.quant_method is not None
+        if self.enable_mlp_optimze:
+            input2_ = get_mlp_tp_group().all_gather(input_, 0)
+            output = self.quant_method.apply(self, input2_, bias)
+        else:
+            output_parallel = self.quant_method.apply(self, input_, bias)
+            if self.gather_output:
+                # All-gather across the partitions.
+                output = tensor_model_parallel_all_gather(output_parallel)
+            else:
+                output = output_parallel
+
+        output_bias = self.bias if self.skip_bias_add else None
+        if not self.return_bias:
+            return output
+        return output, output_bias
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index a0586a0..9a1647c 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -475,9 +475,20 @@ def register_ascend_customop():
     from vllm.model_executor.custom_op import CustomOp
 
     from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul
+    from vllm_ascend.ops.linear import (AscendMlpColumnParallelLinear,
+                                        AscendMlpMergedColumnParallelLinear,
+                                        AscendMlpRowParallelLinear)
     CustomOp.register_oot(_decorated_op_cls=AscendQuickGELU, name="QuickGELU")
     CustomOp.register_oot(_decorated_op_cls=AscendSiluAndMul, name="SiluAndMul")
+    if envs_ascend.VLLM_ASCEND_ENABLE_MLP_OPTIMIZE:
+        CustomOp.register_oot(_decorated_op_cls=AscendMlpColumnParallelLinear,
+                              name="ColumnParallelLinear")
+        CustomOp.register_oot(_decorated_op_cls=AscendMlpRowParallelLinear,
+                              name="RowParallelLinear")
+        CustomOp.register_oot(
+            _decorated_op_cls=AscendMlpMergedColumnParallelLinear,
+            name="MergedColumnParallelLinear")
 
     from vllm_ascend.ops.layernorm import AscendRMSNorm
     CustomOp.register_oot(_decorated_op_cls=AscendRMSNorm, name="RMSNorm")