[feat]: oproj tensor parallelism in pure DP and graph-mode scenarios. (#2167)

### What this PR does / why we need it?
This PR introduces Oproj matrix tensor model parallel to achieve
decreasing of memory consumption. It only support graph mode in pure DP
scenario.

In deepseek r1 w8a8 PD disagregated Decode instance, using pure DP, with
oproj_tensor_parallel_size = 8, we have 1 ms TPOT increasing, saved 5.8
GB NPU memory per RANK. We got best performance when
oproj_tensor_parallel_size=4 without TPOT increasing.

performance data:
<img width="1442" height="442" alt="image"
src="https://github.com/user-attachments/assets/83270fc5-868a-4387-b0a9-fac29b4a376d"
/>

### Does this PR introduce _any_ user-facing change?
This PR introduces one new config in `additional_config`.
| Name | Effect | Required | Type | Constraints |
| :---------------------------- |
:--------------------------------------- | :------- | :--- |
:----------------- |
| oproj_tensor_parallel_size | Split the o_proj matrix along the row
dimension (head num * head dim) into oproj_tensor_parallel_size pieces.
| No | int | default value is None, once this value is set, the feature
will be enabled, head num * head dim must be divisible by this value. |

example

`--additional_config={"oproj_tensor_parallel_size": 8}`

### How was this patch tested?


- vLLM version: v0.10.1.1
- vLLM main:
eddaafc1c7

---------

Signed-off-by: zzhx1 <zzh_201018@outlook.com>
Co-authored-by: zzh <zzh_201018@outlook.com>
This commit is contained in:
lidenghui1110
2025-09-07 10:31:32 +08:00
committed by GitHub
parent a58b43b72c
commit 5a7181569c
23 changed files with 576 additions and 807 deletions

View File

@@ -1,363 +1,105 @@
import os
import unittest
from unittest import mock
from unittest.mock import MagicMock, patch
import torch
from vllm_ascend.ops.linear import (AscendMlpColumnParallelLinear,
AscendMlpMergedColumnParallelLinear,
AscendMlpRowParallelLinear, LinearBase,
QuantizationConfig)
from vllm_ascend import ascend_config
from vllm_ascend.distributed import parallel_state
from vllm_ascend.ops.linear import (AscendColumnParallelLinear,
AscendMergedColumnParallelLinear,
AscendRowParallelLinear)
class TestAscendMlpRowParallelLinear(unittest.TestCase):
class BaseLinearTest(unittest.TestCase):
def setUp(self):
os.environ["VLLM_ASCEND_ENABLE_MLP_OPTIMIZE"] = "1"
self.tensor_parallel_world_size = 2
self.tensor_parallel_rank = 0
self.mlp_tensor_parallel_world_size = 2
self.mlp_tensor_parallel_rank = 1
self.mock_group = mock.MagicMock()
self.mock_group.world_size = 2
self.mock_group.rank_in_group = 0
self.get_tensor_model_parallel_world_size_patch = mock.patch(
'vllm_ascend.ops.linear.get_tensor_model_parallel_world_size',
return_value=self.tensor_parallel_world_size)
self.get_tensor_model_parallel_rank_patch = mock.patch(
'vllm_ascend.ops.linear.get_tensor_model_parallel_rank',
return_value=self.tensor_parallel_rank)
self.get_mlp_tensor_model_parallel_world_size_patch = mock.patch(
'vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_world_size',
return_value=self.mlp_tensor_parallel_world_size)
self.get_mlp_tensor_model_parallel_rank_patch = mock.patch(
'vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_rank',
return_value=self.mlp_tensor_parallel_rank)
parallel_state._MLP_TP = self.mock_group
parallel_state._OTP = self.mock_group
self.get_tensor_model_parallel_world_size_mock = \
self.get_tensor_model_parallel_world_size_patch.start()
self.get_tensor_model_parallel_rank_mock = \
self.get_tensor_model_parallel_rank_patch.start()
self.get_mlp_tensor_model_parallel_world_size_mock = \
self.get_mlp_tensor_model_parallel_world_size_patch.start()
self.get_mlp_tensor_model_parallel_rank_mock = \
self.get_mlp_tensor_model_parallel_rank_patch.start()
self.mock_ascend_config = MagicMock()
self.mock_ascend_config.oproj_tensor_parallel_size = 2
self.split_tensor_along_last_dim_patch = mock.patch(
'vllm_ascend.ops.linear.split_tensor_along_last_dim',
return_value=(torch.randn(10, 8), torch.randn(10, 8)))
self.tensor_model_parallel_all_reduce_patch = mock.patch(
'vllm_ascend.ops.linear.tensor_model_parallel_all_reduce',
return_value=torch.randn(10, 8))
self.tensor_model_parallel_all_reduce_mock = \
self.tensor_model_parallel_all_reduce_patch.start()
self.split_tensor_along_last_dim_mock = \
self.split_tensor_along_last_dim_patch.start()
self.get_mlp_tp_group_patch = \
mock.patch('vllm_ascend.ops.linear.get_mlp_tp_group')
self.get_mlp_tp_group_mock = self.get_mlp_tp_group_patch.start()
self.get_mlp_tp_group_mock.return_value = mock.MagicMock()
self.get_mlp_tp_group_mock.return_value.reduce_scatter = \
mock.MagicMock()
self.patches = [
patch("vllm_ascend.ascend_config.get_ascend_config",
return_value=self.mock_ascend_config),
patch("vllm_ascend.distributed.parallel_state.get_otp_group",
return_value=self.mock_group),
patch("vllm_ascend.distributed.parallel_state.get_mlp_tp_group",
return_value=self.mock_group),
patch("vllm_ascend.ops.linear.get_tp_group",
return_value=self.mock_group),
patch("vllm_ascend.utils.mlp_tp_enable", return_value=True),
patch("vllm_ascend.utils.oproj_tp_enable", return_value=True)
]
for p in self.patches:
p.start()
def tearDown(self):
self.get_tensor_model_parallel_world_size_patch.stop()
self.get_tensor_model_parallel_rank_patch.stop()
self.get_mlp_tensor_model_parallel_world_size_patch.stop()
self.get_mlp_tensor_model_parallel_rank_patch.stop()
self.split_tensor_along_last_dim_patch.stop()
self.tensor_model_parallel_all_reduce_patch.stop()
self.get_mlp_tp_group_patch.stop()
for p in self.patches:
p.stop()
def test_init_with_down_proj_prefix(self):
layer = AscendMlpRowParallelLinear(input_size=16,
output_size=8,
prefix="down_proj")
self.assertEqual(layer.tp_size, self.mlp_tensor_parallel_world_size)
self.assertEqual(layer.tp_rank, self.mlp_tensor_parallel_rank)
self.assertTrue(layer.enable_mlp_optimze)
def test_forward_with_mlp_optimize(self):
layer = AscendMlpRowParallelLinear(
class TestAscendRowParallelLinear(BaseLinearTest):
def test_mlp_optimize(self):
os.environ["VLLM_ASCEND_ENABLE_MLP_OPTIMIZE"] = "1"
linear = AscendRowParallelLinear(
input_size=16,
output_size=8,
prefix="down_proj",
input_is_parallel=False,
)
input_tensor = torch.randn(16, 8) # (batch_size, input_size)
layer(input_tensor)
self.assertEqual(linear.comm_group, parallel_state._MLP_TP)
self.assertEqual(linear.forward_type, "mlp_tp")
self.split_tensor_along_last_dim_mock.assert_called_once_with(
input_tensor, num_partitions=layer.tp_size)
input_tensor = torch.randn(16, 8)
linear(input_tensor)
def test_forward_without_mlp_optimize(self):
layer = AscendMlpRowParallelLinear(
def test_oproj_tp(self):
ascend_config._ASCEND_CONFIG = MagicMock()
ascend_config._ASCEND_CONFIG.oproj_tensor_parallel_size = 2
linear = AscendRowParallelLinear(
input_size=16,
output_size=8,
prefix="other",
input_is_parallel=False,
prefix="o_proj",
)
self.assertEqual(linear.comm_group, parallel_state._OTP)
self.assertEqual(linear.forward_type, "oproj_tp")
input_tensor = torch.randn(16, 8)
layer(input_tensor)
linear(input_tensor)
self.split_tensor_along_last_dim_mock.assert_called_once_with(
input_tensor, num_partitions=layer.tp_size)
self.tensor_model_parallel_all_reduce_mock.assert_called_once()
def test_skip_bias_add(self):
layer = AscendMlpRowParallelLinear(
class TestAscendColumnParallelLinear(BaseLinearTest):
def test_mlp_tp_init(self):
linear = AscendColumnParallelLinear(
input_size=16,
output_size=8,
skip_bias_add=True,
prefix="down_proj",
)
input_tensor = torch.randn(16, 8)
output, bias = layer(input_tensor)
self.assertIsNotNone(bias)
def test_no_reduce_results(self):
layer = AscendMlpRowParallelLinear(input_size=16,
output_size=8,
reduce_results=False,
bias=False)
input_tensor = torch.randn(16, 8)
layer(input_tensor)
self.tensor_model_parallel_all_reduce_mock.assert_not_called()
def test_input_not_parallel(self):
layer = AscendMlpRowParallelLinear(input_size=16,
output_size=8,
input_is_parallel=False)
input_tensor = torch.randn(16, 8)
layer(input_tensor)
self.split_tensor_along_last_dim_mock.assert_called_once()
def test_exception_when_reduce_false_and_bias(self):
with self.assertRaises(ValueError):
AscendMlpRowParallelLinear(input_size=16,
output_size=8,
reduce_results=False,
bias=True,
skip_bias_add=False)
self.assertEqual(linear.comm_group, parallel_state._MLP_TP)
class TestAscendMlpColumnParallelLinear(unittest.TestCase):
class TestAscendMergedColumnParallelLinear(BaseLinearTest):
def setUp(self):
os.environ["VLLM_ASCEND_ENABLE_MLP_OPTIMIZE"] = "1"
# Mock distributed functions
self.mlp_tp_size_patch = \
mock.patch('vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_world_size')
self.mlp_tp_size_mock = self.mlp_tp_size_patch.start()
self.mlp_tp_size_mock.return_value = 2 # Simulate 2 GPUs in MLP TP group
self.mlp_tp_rank_patch = \
mock.patch('vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_rank')
self.mlp_tp_rank_mock = self.mlp_tp_rank_patch.start()
self.mlp_tp_rank_mock.return_value = 0 # Current GPU rank
self.tp_size_patch = \
mock.patch('vllm_ascend.ops.linear.get_tensor_model_parallel_world_size')
self.tp_size_mock = self.tp_size_patch.start()
self.tp_size_mock.return_value = 4 # Simulate 4 GPUs in regular TP group
self.tp_rank_patch = \
mock.patch('vllm_ascend.ops.linear.get_tensor_model_parallel_rank')
self.tp_rank_mock = self.tp_rank_patch.start()
self.tp_rank_mock.return_value = 1 # Current GPU rank
# Mock divide function (assumed to be in your module)
self.divide_patch = mock.patch('vllm_ascend.ops.linear.divide')
self.divide_mock = self.divide_patch.start()
self.divide_mock.side_effect = lambda x, y: x // y # Simulate division
# Mock QuantizationConfig and QuantMethod
self.quant_config_mock = mock.MagicMock(spec=QuantizationConfig)
# Mock LinearBase initialization
self.linear_base_init_patch = mock.patch.object(
LinearBase, "__init__", side_effect=self.mock_linear_base_init)
self.linear_base_init_patch.start()
self.quant_method_mock = mock.MagicMock()
def mock_linear_base_init(self, instance, *args, **kwargs):
instance.quant_method = self.quant_method_mock
instance.params_dtype = mock.MagicMock()
instance.input_size = 16
instance.output_size = 8
instance.output_size_per_partition = 4
instance.params_dtype = torch.float32
def tearDown(self):
self.mlp_tp_size_patch.stop()
self.mlp_tp_rank_patch.stop()
self.tp_size_patch.stop()
self.tp_rank_patch.stop()
self.divide_patch.stop()
self.linear_base_init_patch.stop()
def test_mlp_optimize_initialization(self):
# Test when prefix contains "gate_up_proj"
with mock.patch.object(torch.nn.Module, 'register_parameter'):
layer = AscendMlpColumnParallelLinear(
input_size=16,
output_size=8,
prefix="model.layers.0.gate_up_proj",
bias=False,
)
# Verify MLP optimization flags
self.assertTrue(layer.enable_mlp_optimze)
self.assertEqual(layer.tp_size, 2)
self.assertEqual(layer.tp_rank, 0)
self.assertEqual(layer.input_size_per_partition, 16)
self.assertEqual(layer.output_size_per_partition, 4)
# Check quant_method.create_weights was called
self.quant_method_mock.create_weights.assert_called_once()
def test_regular_parallel_initialization(self):
# Test when prefix does NOT contain "gate_up_proj"
with mock.patch.object(torch.nn.Module, 'register_parameter'):
layer = AscendMlpColumnParallelLinear(
input_size=16,
output_size=8,
prefix="model.layers.0.q_proj",
quant_config=self.quant_config_mock,
bias=False,
)
# Verify regular TP flags
self.assertFalse(layer.enable_mlp_optimze)
self.assertEqual(layer.tp_size, 4)
self.assertEqual(layer.tp_rank, 1)
self.assertEqual(layer.input_size_per_partition, 16)
self.assertEqual(layer.output_size_per_partition, 4)
# Check quant_method.create_weights was called
self.quant_method_mock.create_weights.assert_called_once()
def test_output_sizes_handling(self):
# Test when output_sizes is provided
with mock.patch.object(torch.nn.Module, 'register_parameter'):
layer = AscendMlpColumnParallelLinear(
input_size=16,
output_size=8,
output_sizes=[4, 4],
prefix="model.layers.0.qkv_proj",
quant_config=self.quant_config_mock,
bias=False,
)
# Verify output_partition_sizes
self.assertEqual(layer.output_partition_sizes, [2])
def test_merged_mlp_tp_init(self):
linear = AscendMergedColumnParallelLinear(
input_size=16,
output_sizes=[8, 8],
prefix="gate_up_proj",
)
self.assertEqual(linear.comm_group, parallel_state._MLP_TP)
self.assertEqual(linear.forward_type, "mlp_tp")
class TestAscendMlpMergedColumnParallelLinear(unittest.TestCase):
def setUp(self):
os.environ["VLLM_ASCEND_ENABLE_MLP_OPTIMIZE"] = "1"
# Mock get_mlp_tensor_model_parallel_world_size and get_tensor_model_parallel_world_size
self.mlp_world_size_patch = \
mock.patch("vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_world_size", return_value=2)
self.tensor_world_size_patch = \
mock.patch("vllm_ascend.ops.linear.get_tensor_model_parallel_world_size", return_value=2)
self.mlp_world_size_patch.start()
self.tensor_world_size_patch.start()
# Mock get_mlp_tensor_model_parallel_rank and get_tensor_model_parallel_rank
self.mlp_rank_patch = \
mock.patch("vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_rank", return_value=0)
self.tensor_rank_patch = \
mock.patch("vllm_ascend.ops.linear.get_tensor_model_parallel_rank", return_value=0)
self.mlp_rank_patch.start()
self.tensor_rank_patch.start()
# Mock all_gather methods
self.get_mlp_tp_group_patch = \
mock.patch('vllm_ascend.ops.linear.get_mlp_tp_group')
self.get_mlp_tp_group_mock = self.get_mlp_tp_group_patch.start()
self.get_mlp_tp_group_mock.return_value = mock.MagicMock()
self.get_mlp_tp_group_mock.return_value.all_gather = mock.MagicMock()
self.tensor_model_parallel_all_gather_patch = mock.patch(
'vllm_ascend.ops.linear.tensor_model_parallel_all_gather',
return_value=torch.randn(10, 8))
self.tensor_model_parallel_all_gather_mock = \
self.tensor_model_parallel_all_gather_patch.start()
# Mock AscendMlpColumnParallelLinear's __init__
self.linear_init_patch = mock.patch.object(
AscendMlpColumnParallelLinear,
"__init__",
side_effect=self.mock_linear_init)
self.linear_init_patch.start()
# Create mock objects
self.quant_method_mock = mock.MagicMock()
self.apply_output = torch.randn(2, 8)
self.quant_method_mock.apply.return_value = self.apply_output
def mock_linear_init(self, instance, *args, **kwargs):
torch.nn.Module.__init__(instance)
# Set quant_method and other attributes
instance.quant_method = self.quant_method_mock
instance.bias = torch.nn.Parameter(torch.randn(8)) # Example bias
instance.input_size = 16
instance.output_size = 8
instance.gather_output = False
instance.skip_bias_add = False
instance.return_bias = True
def test_forward_with_enable_mlp_optimze(self):
# Setup input
input_tensor = torch.randn(1, 16)
# Create instance with prefix "gate_up_proj" to trigger enable_mlp_optimze = True
layer = AscendMlpMergedColumnParallelLinear(input_size=16,
output_sizes=[8],
bias=True,
gather_output=False,
skip_bias_add=False,
params_dtype=torch.float32,
quant_config=None,
prefix="other_proj")
# Call forward
output, bias = layer(input_tensor)
# Validate calls
self.assertEqual(output.shape, self.apply_output.shape)
def test_forward_without_enable_mlp_optimze(self):
# Setup input
input_tensor = torch.randn(1, 16)
# Create instance with prefix not containing "gate_up_proj"
layer = AscendMlpMergedColumnParallelLinear(input_size=16,
output_sizes=[8],
bias=True,
gather_output=False,
skip_bias_add=False,
params_dtype=torch.float32,
quant_config=None,
prefix="other_proj")
# Call forward
output, bias = layer(input_tensor)
# Validate calls
self.quant_method_mock.apply.assert_called_once_with(
layer, input_tensor, layer.bias)
self.tensor_model_parallel_all_gather_mock.assert_not_called()
self.assertEqual(output.shape, self.apply_output.shape)
def tearDown(self):
self.linear_init_patch.stop()
self.mlp_world_size_patch.stop()
self.tensor_world_size_patch.stop()
self.mlp_rank_patch.stop()
self.tensor_rank_patch.stop()
self.get_mlp_tp_group_mock.stop()
self.tensor_model_parallel_all_gather_mock.stop()
if __name__ == '__main__':
unittest.main()