### What this PR does / why we need it?
This PR introduces tensor model parallelism for the o_proj matrix to reduce memory consumption. It only supports graph mode in the pure-DP scenario.
On a DeepSeek R1 W8A8 PD-disaggregated decode instance using pure DP, setting `oproj_tensor_parallel_size = 8` saved 5.8 GB of NPU memory per rank at the cost of a 1 ms TPOT increase. We got the best trade-off with `oproj_tensor_parallel_size = 4`, which showed no TPOT increase.
Performance data:
<img width="1442" height="442" alt="image" src="https://github.com/user-attachments/assets/83270fc5-868a-4387-b0a9-fac29b4a376d" />
### Does this PR introduce _any_ user-facing change?
This PR introduces one new option in `additional_config`.

| Name | Effect | Required | Type | Constraints |
| :--- | :--- | :--- | :--- | :--- |
| `oproj_tensor_parallel_size` | Splits the o_proj matrix along its row dimension (head num * head dim) into `oproj_tensor_parallel_size` pieces. | No | int | Defaults to `None`; once a value is set, the feature is enabled. head num * head dim must be divisible by this value. |

Example:
`--additional_config={"oproj_tensor_parallel_size": 8}`
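For intuition, here is a minimal numerical sketch of what the row split means. The shapes (`num_heads`, `head_dim`, `hidden_size`) are illustrative values, not defaults; at runtime each rank would hold one shard, and the partial products would be combined with a reduce across the oproj group:

```python
import torch

# Toy shapes for illustration only.
num_heads, head_dim, hidden_size = 16, 8, 32
oproj_tp = 4  # oproj_tensor_parallel_size

# o_proj maps (head num * head dim) -> hidden_size, so its row
# (input) dimension is num_heads * head_dim.
assert (num_heads * head_dim) % oproj_tp == 0  # divisibility constraint
w = torch.randn(num_heads * head_dim, hidden_size)

# Row split: each rank holds one contiguous slice of the input dimension.
w_shards = torch.chunk(w, oproj_tp, dim=0)

# Each rank multiplies its input slice by its weight shard; summing the
# partial products (a reduce across the oproj group at runtime) recovers
# the full output.
x = torch.randn(1, num_heads * head_dim)
x_slices = torch.chunk(x, oproj_tp, dim=1)
partial = [x_slices[i] @ w_shards[i] for i in range(oproj_tp)]
assert torch.allclose(sum(partial), x @ w, atol=1e-4)
```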
### How was this patch tested?
- vLLM version: v0.10.1.1
- vLLM main: eddaafc1c7
---------
Signed-off-by: zzhx1 <zzh_201018@outlook.com>
Co-authored-by: zzh <zzh_201018@outlook.com>
import os
import unittest
from unittest import mock
from unittest.mock import MagicMock, patch

import torch

from vllm_ascend import ascend_config
from vllm_ascend.distributed import parallel_state
from vllm_ascend.ops.linear import (AscendColumnParallelLinear,
                                    AscendMergedColumnParallelLinear,
                                    AscendRowParallelLinear)


class BaseLinearTest(unittest.TestCase):

    def setUp(self):
        # Fake a two-rank communication group, shared by the MLP-TP and
        # oproj-TP process groups.
        self.mock_group = mock.MagicMock()
        self.mock_group.world_size = 2
        self.mock_group.rank_in_group = 0

        parallel_state._MLP_TP = self.mock_group
        parallel_state._OTP = self.mock_group

        self.mock_ascend_config = MagicMock()
        self.mock_ascend_config.oproj_tensor_parallel_size = 2

        self.patches = [
            patch("vllm_ascend.ascend_config.get_ascend_config",
                  return_value=self.mock_ascend_config),
            patch("vllm_ascend.distributed.parallel_state.get_otp_group",
                  return_value=self.mock_group),
            patch("vllm_ascend.distributed.parallel_state.get_mlp_tp_group",
                  return_value=self.mock_group),
            patch("vllm_ascend.ops.linear.get_tp_group",
                  return_value=self.mock_group),
            patch("vllm_ascend.utils.mlp_tp_enable", return_value=True),
            patch("vllm_ascend.utils.oproj_tp_enable", return_value=True),
        ]

        for p in self.patches:
            p.start()

    def tearDown(self):
        for p in self.patches:
            p.stop()


class TestAscendRowParallelLinear(BaseLinearTest):

    def test_mlp_optimize(self):
        os.environ["VLLM_ASCEND_ENABLE_MLP_OPTIMIZE"] = "1"

        linear = AscendRowParallelLinear(
            input_size=16,
            output_size=8,
            prefix="down_proj",
        )
        # A down_proj layer should be routed to the MLP tensor-parallel group.
        self.assertEqual(linear.comm_group, parallel_state._MLP_TP)
        self.assertEqual(linear.forward_type, "mlp_tp")

        input_tensor = torch.randn(16, 8)
        linear(input_tensor)

    def test_oproj_tp(self):
        ascend_config._ASCEND_CONFIG = MagicMock()
        ascend_config._ASCEND_CONFIG.oproj_tensor_parallel_size = 2

        linear = AscendRowParallelLinear(
            input_size=16,
            output_size=8,
            prefix="o_proj",
        )
        # An o_proj layer should be routed to the oproj tensor-parallel group.
        self.assertEqual(linear.comm_group, parallel_state._OTP)
        self.assertEqual(linear.forward_type, "oproj_tp")

        input_tensor = torch.randn(16, 8)
        linear(input_tensor)


class TestAscendColumnParallelLinear(BaseLinearTest):

    def test_mlp_tp_init(self):
        linear = AscendColumnParallelLinear(
            input_size=16,
            output_size=8,
            prefix="down_proj",
        )
        self.assertEqual(linear.comm_group, parallel_state._MLP_TP)


class TestAscendMergedColumnParallelLinear(BaseLinearTest):

    def test_merged_mlp_tp_init(self):
        linear = AscendMergedColumnParallelLinear(
            input_size=16,
            output_sizes=[8, 8],
            prefix="gate_up_proj",
        )
        self.assertEqual(linear.comm_group, parallel_state._MLP_TP)
        self.assertEqual(linear.forward_type, "mlp_tp")


if __name__ == '__main__':
    unittest.main()
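Taken together, the assertions in these tests imply a prefix-based routing rule inside `AscendRowParallelLinear`. The sketch below is a reconstruction from the test expectations only, not the actual implementation; the function name `select_comm_group` and the `"normal_tp"` fallback are hypothetical.

```python
# Hypothetical reconstruction of the routing these tests assert; the real
# logic lives in vllm_ascend.ops.linear.
def select_comm_group(prefix: str, oproj_tp_enabled: bool,
                      mlp_tp_enabled: bool) -> tuple[str, str]:
    """Return (comm group name, forward_type) for a row-parallel layer."""
    if "o_proj" in prefix and oproj_tp_enabled:
        return "_OTP", "oproj_tp"      # attention output projection
    if "down_proj" in prefix and mlp_tp_enabled:
        return "_MLP_TP", "mlp_tp"     # MLP down projection
    return "_TP", "normal_tp"          # assumed default: the regular TP group


assert select_comm_group("o_proj", True, True) == ("_OTP", "oproj_tp")
assert select_comm_group("down_proj", True, True) == ("_MLP_TP", "mlp_tp")
```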