[feat]: oproj tensor parallelism in pure DP and graph-mode scenarios. (#2167)
### What this PR does / why we need it?
This PR introduces Oproj matrix tensor model parallel to achieve
decreasing of memory consumption. It only support graph mode in pure DP
scenario.
In deepseek r1 w8a8 PD disagregated Decode instance, using pure DP, with
oproj_tensor_parallel_size = 8, we have 1 ms TPOT increasing, saved 5.8
GB NPU memory per RANK. We got best performance when
oproj_tensor_parallel_size=4 without TPOT increasing.
performance data:
<img width="1442" height="442" alt="image"
src="https://github.com/user-attachments/assets/83270fc5-868a-4387-b0a9-fac29b4a376d"
/>
### Does this PR introduce _any_ user-facing change?
This PR introduces one new config in `additional_config`.
| Name | Effect | Required | Type | Constraints |
| :---------------------------- |
:--------------------------------------- | :------- | :--- |
:----------------- |
| oproj_tensor_parallel_size | Split the o_proj matrix along the row
dimension (head num * head dim) into oproj_tensor_parallel_size pieces.
| No | int | default value is None, once this value is set, the feature
will be enabled, head num * head dim must be divisible by this value. |
example
`--additional_config={"oproj_tensor_parallel_size": 8}`
### How was this patch tested?
- vLLM version: v0.10.1.1
- vLLM main:
eddaafc1c7
---------
Signed-off-by: zzhx1 <zzh_201018@outlook.com>
Co-authored-by: zzh <zzh_201018@outlook.com>
This commit is contained in:
@@ -4,8 +4,8 @@ import pytest
|
||||
from vllm.config import ParallelConfig
|
||||
|
||||
from vllm_ascend.distributed.parallel_state import (
|
||||
_LMTP, _MC2, destroy_ascend_model_parallel, get_lmhead_tp_group,
|
||||
get_mc2_group, init_ascend_model_parallel)
|
||||
_LMTP, _MC2, _OTP, destroy_ascend_model_parallel, get_lmhead_tp_group,
|
||||
get_mc2_group, get_otp_group, init_ascend_model_parallel)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -29,16 +29,20 @@ def mock_distributed():
|
||||
def test_init_ascend_model_parallel(mock_distributed, parallel_config):
|
||||
mock_ascend_config = MagicMock()
|
||||
mock_ascend_config.lmhead_tensor_parallel_size = 2
|
||||
mock_ascend_config.oproj_tensor_parallel_size = 2
|
||||
with patch('vllm_ascend.distributed.parallel_state.model_parallel_initialized', return_value=False), \
|
||||
patch('vllm_ascend.distributed.parallel_state.init_model_parallel_group'), \
|
||||
patch('vllm_ascend.distributed.parallel_state.get_ascend_config', return_value=mock_ascend_config):
|
||||
init_ascend_model_parallel(parallel_config)
|
||||
|
||||
mc2_group = get_mc2_group()
|
||||
assert mc2_group is not None
|
||||
lmheadtp_group = get_lmhead_tp_group()
|
||||
otp_group = get_otp_group()
|
||||
assert mc2_group is not None
|
||||
assert otp_group is not None
|
||||
assert lmheadtp_group is not None
|
||||
|
||||
destroy_ascend_model_parallel()
|
||||
assert _MC2 is None
|
||||
assert _LMTP is None
|
||||
assert _OTP is None
|
||||
|
||||
@@ -174,7 +174,6 @@ def test_row_parallel_linear(cls, mock_distributed):
|
||||
linear = cls(input_size=128, output_size=64, bias=False, quant_config=None)
|
||||
linear.quant_method = Mock()
|
||||
linear.quant_method.apply.return_value = torch.randn(2, 4, 64)
|
||||
|
||||
input_ = torch.randn(2, 4, 128)
|
||||
with patch("vllm_ascend.models.deepseek_v2.split_tensor_along_last_dim",
|
||||
return_value=[torch.randn(2, 4, 64)]):
|
||||
|
||||
@@ -286,6 +286,18 @@ class TestAscendQwen2_5_VisionTransformer(PytestBase):
|
||||
"vllm_ascend.models.qwen2_5_vl.parallel_state.get_tensor_model_parallel_world_size",
|
||||
return_value=2,
|
||||
)
|
||||
mocker.patch(
|
||||
"vllm_ascend.ops.linear.divide",
|
||||
return_value=2,
|
||||
)
|
||||
|
||||
mock_group = mocker.MagicMock()
|
||||
mock_group.rank_in_group = 0
|
||||
mock_group.world_size = 2
|
||||
mocker.patch(
|
||||
"vllm_ascend.ops.linear.get_tp_group",
|
||||
return_value=mock_group,
|
||||
)
|
||||
|
||||
vision_transformer = AscendQwen2_5_VisionTransformer(
|
||||
vision_config,
|
||||
|
||||
@@ -1,363 +1,105 @@
|
||||
import os
|
||||
import unittest
|
||||
from unittest import mock
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import torch
|
||||
|
||||
from vllm_ascend.ops.linear import (AscendMlpColumnParallelLinear,
|
||||
AscendMlpMergedColumnParallelLinear,
|
||||
AscendMlpRowParallelLinear, LinearBase,
|
||||
QuantizationConfig)
|
||||
from vllm_ascend import ascend_config
|
||||
from vllm_ascend.distributed import parallel_state
|
||||
from vllm_ascend.ops.linear import (AscendColumnParallelLinear,
|
||||
AscendMergedColumnParallelLinear,
|
||||
AscendRowParallelLinear)
|
||||
|
||||
|
||||
class TestAscendMlpRowParallelLinear(unittest.TestCase):
|
||||
class BaseLinearTest(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
os.environ["VLLM_ASCEND_ENABLE_MLP_OPTIMIZE"] = "1"
|
||||
self.tensor_parallel_world_size = 2
|
||||
self.tensor_parallel_rank = 0
|
||||
self.mlp_tensor_parallel_world_size = 2
|
||||
self.mlp_tensor_parallel_rank = 1
|
||||
self.mock_group = mock.MagicMock()
|
||||
self.mock_group.world_size = 2
|
||||
self.mock_group.rank_in_group = 0
|
||||
|
||||
self.get_tensor_model_parallel_world_size_patch = mock.patch(
|
||||
'vllm_ascend.ops.linear.get_tensor_model_parallel_world_size',
|
||||
return_value=self.tensor_parallel_world_size)
|
||||
self.get_tensor_model_parallel_rank_patch = mock.patch(
|
||||
'vllm_ascend.ops.linear.get_tensor_model_parallel_rank',
|
||||
return_value=self.tensor_parallel_rank)
|
||||
self.get_mlp_tensor_model_parallel_world_size_patch = mock.patch(
|
||||
'vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_world_size',
|
||||
return_value=self.mlp_tensor_parallel_world_size)
|
||||
self.get_mlp_tensor_model_parallel_rank_patch = mock.patch(
|
||||
'vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_rank',
|
||||
return_value=self.mlp_tensor_parallel_rank)
|
||||
parallel_state._MLP_TP = self.mock_group
|
||||
parallel_state._OTP = self.mock_group
|
||||
|
||||
self.get_tensor_model_parallel_world_size_mock = \
|
||||
self.get_tensor_model_parallel_world_size_patch.start()
|
||||
self.get_tensor_model_parallel_rank_mock = \
|
||||
self.get_tensor_model_parallel_rank_patch.start()
|
||||
self.get_mlp_tensor_model_parallel_world_size_mock = \
|
||||
self.get_mlp_tensor_model_parallel_world_size_patch.start()
|
||||
self.get_mlp_tensor_model_parallel_rank_mock = \
|
||||
self.get_mlp_tensor_model_parallel_rank_patch.start()
|
||||
self.mock_ascend_config = MagicMock()
|
||||
self.mock_ascend_config.oproj_tensor_parallel_size = 2
|
||||
|
||||
self.split_tensor_along_last_dim_patch = mock.patch(
|
||||
'vllm_ascend.ops.linear.split_tensor_along_last_dim',
|
||||
return_value=(torch.randn(10, 8), torch.randn(10, 8)))
|
||||
self.tensor_model_parallel_all_reduce_patch = mock.patch(
|
||||
'vllm_ascend.ops.linear.tensor_model_parallel_all_reduce',
|
||||
return_value=torch.randn(10, 8))
|
||||
self.tensor_model_parallel_all_reduce_mock = \
|
||||
self.tensor_model_parallel_all_reduce_patch.start()
|
||||
self.split_tensor_along_last_dim_mock = \
|
||||
self.split_tensor_along_last_dim_patch.start()
|
||||
self.get_mlp_tp_group_patch = \
|
||||
mock.patch('vllm_ascend.ops.linear.get_mlp_tp_group')
|
||||
self.get_mlp_tp_group_mock = self.get_mlp_tp_group_patch.start()
|
||||
self.get_mlp_tp_group_mock.return_value = mock.MagicMock()
|
||||
self.get_mlp_tp_group_mock.return_value.reduce_scatter = \
|
||||
mock.MagicMock()
|
||||
self.patches = [
|
||||
patch("vllm_ascend.ascend_config.get_ascend_config",
|
||||
return_value=self.mock_ascend_config),
|
||||
patch("vllm_ascend.distributed.parallel_state.get_otp_group",
|
||||
return_value=self.mock_group),
|
||||
patch("vllm_ascend.distributed.parallel_state.get_mlp_tp_group",
|
||||
return_value=self.mock_group),
|
||||
patch("vllm_ascend.ops.linear.get_tp_group",
|
||||
return_value=self.mock_group),
|
||||
patch("vllm_ascend.utils.mlp_tp_enable", return_value=True),
|
||||
patch("vllm_ascend.utils.oproj_tp_enable", return_value=True)
|
||||
]
|
||||
|
||||
for p in self.patches:
|
||||
p.start()
|
||||
|
||||
def tearDown(self):
|
||||
self.get_tensor_model_parallel_world_size_patch.stop()
|
||||
self.get_tensor_model_parallel_rank_patch.stop()
|
||||
self.get_mlp_tensor_model_parallel_world_size_patch.stop()
|
||||
self.get_mlp_tensor_model_parallel_rank_patch.stop()
|
||||
self.split_tensor_along_last_dim_patch.stop()
|
||||
self.tensor_model_parallel_all_reduce_patch.stop()
|
||||
self.get_mlp_tp_group_patch.stop()
|
||||
for p in self.patches:
|
||||
p.stop()
|
||||
|
||||
def test_init_with_down_proj_prefix(self):
|
||||
layer = AscendMlpRowParallelLinear(input_size=16,
|
||||
output_size=8,
|
||||
prefix="down_proj")
|
||||
self.assertEqual(layer.tp_size, self.mlp_tensor_parallel_world_size)
|
||||
self.assertEqual(layer.tp_rank, self.mlp_tensor_parallel_rank)
|
||||
self.assertTrue(layer.enable_mlp_optimze)
|
||||
|
||||
def test_forward_with_mlp_optimize(self):
|
||||
layer = AscendMlpRowParallelLinear(
|
||||
class TestAscendRowParallelLinear(BaseLinearTest):
|
||||
|
||||
def test_mlp_optimize(self):
|
||||
os.environ["VLLM_ASCEND_ENABLE_MLP_OPTIMIZE"] = "1"
|
||||
|
||||
linear = AscendRowParallelLinear(
|
||||
input_size=16,
|
||||
output_size=8,
|
||||
prefix="down_proj",
|
||||
input_is_parallel=False,
|
||||
)
|
||||
input_tensor = torch.randn(16, 8) # (batch_size, input_size)
|
||||
layer(input_tensor)
|
||||
self.assertEqual(linear.comm_group, parallel_state._MLP_TP)
|
||||
self.assertEqual(linear.forward_type, "mlp_tp")
|
||||
|
||||
self.split_tensor_along_last_dim_mock.assert_called_once_with(
|
||||
input_tensor, num_partitions=layer.tp_size)
|
||||
input_tensor = torch.randn(16, 8)
|
||||
linear(input_tensor)
|
||||
|
||||
def test_forward_without_mlp_optimize(self):
|
||||
layer = AscendMlpRowParallelLinear(
|
||||
def test_oproj_tp(self):
|
||||
ascend_config._ASCEND_CONFIG = MagicMock()
|
||||
ascend_config._ASCEND_CONFIG.oproj_tensor_parallel_size = 2
|
||||
|
||||
linear = AscendRowParallelLinear(
|
||||
input_size=16,
|
||||
output_size=8,
|
||||
prefix="other",
|
||||
input_is_parallel=False,
|
||||
prefix="o_proj",
|
||||
)
|
||||
self.assertEqual(linear.comm_group, parallel_state._OTP)
|
||||
self.assertEqual(linear.forward_type, "oproj_tp")
|
||||
|
||||
input_tensor = torch.randn(16, 8)
|
||||
layer(input_tensor)
|
||||
linear(input_tensor)
|
||||
|
||||
self.split_tensor_along_last_dim_mock.assert_called_once_with(
|
||||
input_tensor, num_partitions=layer.tp_size)
|
||||
self.tensor_model_parallel_all_reduce_mock.assert_called_once()
|
||||
|
||||
def test_skip_bias_add(self):
|
||||
layer = AscendMlpRowParallelLinear(
|
||||
class TestAscendColumnParallelLinear(BaseLinearTest):
|
||||
|
||||
def test_mlp_tp_init(self):
|
||||
linear = AscendColumnParallelLinear(
|
||||
input_size=16,
|
||||
output_size=8,
|
||||
skip_bias_add=True,
|
||||
prefix="down_proj",
|
||||
)
|
||||
input_tensor = torch.randn(16, 8)
|
||||
output, bias = layer(input_tensor)
|
||||
|
||||
self.assertIsNotNone(bias)
|
||||
|
||||
def test_no_reduce_results(self):
|
||||
layer = AscendMlpRowParallelLinear(input_size=16,
|
||||
output_size=8,
|
||||
reduce_results=False,
|
||||
bias=False)
|
||||
input_tensor = torch.randn(16, 8)
|
||||
layer(input_tensor)
|
||||
|
||||
self.tensor_model_parallel_all_reduce_mock.assert_not_called()
|
||||
|
||||
def test_input_not_parallel(self):
|
||||
layer = AscendMlpRowParallelLinear(input_size=16,
|
||||
output_size=8,
|
||||
input_is_parallel=False)
|
||||
input_tensor = torch.randn(16, 8)
|
||||
layer(input_tensor)
|
||||
|
||||
self.split_tensor_along_last_dim_mock.assert_called_once()
|
||||
|
||||
def test_exception_when_reduce_false_and_bias(self):
|
||||
with self.assertRaises(ValueError):
|
||||
AscendMlpRowParallelLinear(input_size=16,
|
||||
output_size=8,
|
||||
reduce_results=False,
|
||||
bias=True,
|
||||
skip_bias_add=False)
|
||||
self.assertEqual(linear.comm_group, parallel_state._MLP_TP)
|
||||
|
||||
|
||||
class TestAscendMlpColumnParallelLinear(unittest.TestCase):
|
||||
class TestAscendMergedColumnParallelLinear(BaseLinearTest):
|
||||
|
||||
def setUp(self):
|
||||
os.environ["VLLM_ASCEND_ENABLE_MLP_OPTIMIZE"] = "1"
|
||||
# Mock distributed functions
|
||||
self.mlp_tp_size_patch = \
|
||||
mock.patch('vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_world_size')
|
||||
self.mlp_tp_size_mock = self.mlp_tp_size_patch.start()
|
||||
self.mlp_tp_size_mock.return_value = 2 # Simulate 2 GPUs in MLP TP group
|
||||
|
||||
self.mlp_tp_rank_patch = \
|
||||
mock.patch('vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_rank')
|
||||
self.mlp_tp_rank_mock = self.mlp_tp_rank_patch.start()
|
||||
self.mlp_tp_rank_mock.return_value = 0 # Current GPU rank
|
||||
|
||||
self.tp_size_patch = \
|
||||
mock.patch('vllm_ascend.ops.linear.get_tensor_model_parallel_world_size')
|
||||
self.tp_size_mock = self.tp_size_patch.start()
|
||||
self.tp_size_mock.return_value = 4 # Simulate 4 GPUs in regular TP group
|
||||
|
||||
self.tp_rank_patch = \
|
||||
mock.patch('vllm_ascend.ops.linear.get_tensor_model_parallel_rank')
|
||||
self.tp_rank_mock = self.tp_rank_patch.start()
|
||||
self.tp_rank_mock.return_value = 1 # Current GPU rank
|
||||
|
||||
# Mock divide function (assumed to be in your module)
|
||||
self.divide_patch = mock.patch('vllm_ascend.ops.linear.divide')
|
||||
self.divide_mock = self.divide_patch.start()
|
||||
self.divide_mock.side_effect = lambda x, y: x // y # Simulate division
|
||||
|
||||
# Mock QuantizationConfig and QuantMethod
|
||||
self.quant_config_mock = mock.MagicMock(spec=QuantizationConfig)
|
||||
|
||||
# Mock LinearBase initialization
|
||||
self.linear_base_init_patch = mock.patch.object(
|
||||
LinearBase, "__init__", side_effect=self.mock_linear_base_init)
|
||||
self.linear_base_init_patch.start()
|
||||
|
||||
self.quant_method_mock = mock.MagicMock()
|
||||
|
||||
def mock_linear_base_init(self, instance, *args, **kwargs):
|
||||
instance.quant_method = self.quant_method_mock
|
||||
instance.params_dtype = mock.MagicMock()
|
||||
|
||||
instance.input_size = 16
|
||||
instance.output_size = 8
|
||||
instance.output_size_per_partition = 4
|
||||
instance.params_dtype = torch.float32
|
||||
|
||||
def tearDown(self):
|
||||
self.mlp_tp_size_patch.stop()
|
||||
self.mlp_tp_rank_patch.stop()
|
||||
self.tp_size_patch.stop()
|
||||
self.tp_rank_patch.stop()
|
||||
self.divide_patch.stop()
|
||||
self.linear_base_init_patch.stop()
|
||||
|
||||
def test_mlp_optimize_initialization(self):
|
||||
# Test when prefix contains "gate_up_proj"
|
||||
with mock.patch.object(torch.nn.Module, 'register_parameter'):
|
||||
layer = AscendMlpColumnParallelLinear(
|
||||
input_size=16,
|
||||
output_size=8,
|
||||
prefix="model.layers.0.gate_up_proj",
|
||||
bias=False,
|
||||
)
|
||||
|
||||
# Verify MLP optimization flags
|
||||
self.assertTrue(layer.enable_mlp_optimze)
|
||||
self.assertEqual(layer.tp_size, 2)
|
||||
self.assertEqual(layer.tp_rank, 0)
|
||||
self.assertEqual(layer.input_size_per_partition, 16)
|
||||
self.assertEqual(layer.output_size_per_partition, 4)
|
||||
|
||||
# Check quant_method.create_weights was called
|
||||
self.quant_method_mock.create_weights.assert_called_once()
|
||||
|
||||
def test_regular_parallel_initialization(self):
|
||||
# Test when prefix does NOT contain "gate_up_proj"
|
||||
with mock.patch.object(torch.nn.Module, 'register_parameter'):
|
||||
layer = AscendMlpColumnParallelLinear(
|
||||
input_size=16,
|
||||
output_size=8,
|
||||
prefix="model.layers.0.q_proj",
|
||||
quant_config=self.quant_config_mock,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
# Verify regular TP flags
|
||||
self.assertFalse(layer.enable_mlp_optimze)
|
||||
self.assertEqual(layer.tp_size, 4)
|
||||
self.assertEqual(layer.tp_rank, 1)
|
||||
self.assertEqual(layer.input_size_per_partition, 16)
|
||||
self.assertEqual(layer.output_size_per_partition, 4)
|
||||
# Check quant_method.create_weights was called
|
||||
self.quant_method_mock.create_weights.assert_called_once()
|
||||
|
||||
def test_output_sizes_handling(self):
|
||||
# Test when output_sizes is provided
|
||||
with mock.patch.object(torch.nn.Module, 'register_parameter'):
|
||||
layer = AscendMlpColumnParallelLinear(
|
||||
input_size=16,
|
||||
output_size=8,
|
||||
output_sizes=[4, 4],
|
||||
prefix="model.layers.0.qkv_proj",
|
||||
quant_config=self.quant_config_mock,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
# Verify output_partition_sizes
|
||||
self.assertEqual(layer.output_partition_sizes, [2])
|
||||
def test_merged_mlp_tp_init(self):
|
||||
linear = AscendMergedColumnParallelLinear(
|
||||
input_size=16,
|
||||
output_sizes=[8, 8],
|
||||
prefix="gate_up_proj",
|
||||
)
|
||||
self.assertEqual(linear.comm_group, parallel_state._MLP_TP)
|
||||
self.assertEqual(linear.forward_type, "mlp_tp")
|
||||
|
||||
|
||||
class TestAscendMlpMergedColumnParallelLinear(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
os.environ["VLLM_ASCEND_ENABLE_MLP_OPTIMIZE"] = "1"
|
||||
# Mock get_mlp_tensor_model_parallel_world_size and get_tensor_model_parallel_world_size
|
||||
self.mlp_world_size_patch = \
|
||||
mock.patch("vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_world_size", return_value=2)
|
||||
self.tensor_world_size_patch = \
|
||||
mock.patch("vllm_ascend.ops.linear.get_tensor_model_parallel_world_size", return_value=2)
|
||||
self.mlp_world_size_patch.start()
|
||||
self.tensor_world_size_patch.start()
|
||||
|
||||
# Mock get_mlp_tensor_model_parallel_rank and get_tensor_model_parallel_rank
|
||||
self.mlp_rank_patch = \
|
||||
mock.patch("vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_rank", return_value=0)
|
||||
self.tensor_rank_patch = \
|
||||
mock.patch("vllm_ascend.ops.linear.get_tensor_model_parallel_rank", return_value=0)
|
||||
self.mlp_rank_patch.start()
|
||||
self.tensor_rank_patch.start()
|
||||
|
||||
# Mock all_gather methods
|
||||
self.get_mlp_tp_group_patch = \
|
||||
mock.patch('vllm_ascend.ops.linear.get_mlp_tp_group')
|
||||
self.get_mlp_tp_group_mock = self.get_mlp_tp_group_patch.start()
|
||||
self.get_mlp_tp_group_mock.return_value = mock.MagicMock()
|
||||
self.get_mlp_tp_group_mock.return_value.all_gather = mock.MagicMock()
|
||||
self.tensor_model_parallel_all_gather_patch = mock.patch(
|
||||
'vllm_ascend.ops.linear.tensor_model_parallel_all_gather',
|
||||
return_value=torch.randn(10, 8))
|
||||
self.tensor_model_parallel_all_gather_mock = \
|
||||
self.tensor_model_parallel_all_gather_patch.start()
|
||||
|
||||
# Mock AscendMlpColumnParallelLinear's __init__
|
||||
self.linear_init_patch = mock.patch.object(
|
||||
AscendMlpColumnParallelLinear,
|
||||
"__init__",
|
||||
side_effect=self.mock_linear_init)
|
||||
self.linear_init_patch.start()
|
||||
|
||||
# Create mock objects
|
||||
self.quant_method_mock = mock.MagicMock()
|
||||
self.apply_output = torch.randn(2, 8)
|
||||
|
||||
self.quant_method_mock.apply.return_value = self.apply_output
|
||||
|
||||
def mock_linear_init(self, instance, *args, **kwargs):
|
||||
torch.nn.Module.__init__(instance)
|
||||
# Set quant_method and other attributes
|
||||
instance.quant_method = self.quant_method_mock
|
||||
instance.bias = torch.nn.Parameter(torch.randn(8)) # Example bias
|
||||
instance.input_size = 16
|
||||
instance.output_size = 8
|
||||
instance.gather_output = False
|
||||
instance.skip_bias_add = False
|
||||
instance.return_bias = True
|
||||
|
||||
def test_forward_with_enable_mlp_optimze(self):
|
||||
# Setup input
|
||||
input_tensor = torch.randn(1, 16)
|
||||
|
||||
# Create instance with prefix "gate_up_proj" to trigger enable_mlp_optimze = True
|
||||
layer = AscendMlpMergedColumnParallelLinear(input_size=16,
|
||||
output_sizes=[8],
|
||||
bias=True,
|
||||
gather_output=False,
|
||||
skip_bias_add=False,
|
||||
params_dtype=torch.float32,
|
||||
quant_config=None,
|
||||
prefix="other_proj")
|
||||
|
||||
# Call forward
|
||||
output, bias = layer(input_tensor)
|
||||
|
||||
# Validate calls
|
||||
self.assertEqual(output.shape, self.apply_output.shape)
|
||||
|
||||
def test_forward_without_enable_mlp_optimze(self):
|
||||
# Setup input
|
||||
input_tensor = torch.randn(1, 16)
|
||||
|
||||
# Create instance with prefix not containing "gate_up_proj"
|
||||
layer = AscendMlpMergedColumnParallelLinear(input_size=16,
|
||||
output_sizes=[8],
|
||||
bias=True,
|
||||
gather_output=False,
|
||||
skip_bias_add=False,
|
||||
params_dtype=torch.float32,
|
||||
quant_config=None,
|
||||
prefix="other_proj")
|
||||
|
||||
# Call forward
|
||||
output, bias = layer(input_tensor)
|
||||
|
||||
# Validate calls
|
||||
self.quant_method_mock.apply.assert_called_once_with(
|
||||
layer, input_tensor, layer.bias)
|
||||
self.tensor_model_parallel_all_gather_mock.assert_not_called()
|
||||
self.assertEqual(output.shape, self.apply_output.shape)
|
||||
|
||||
def tearDown(self):
|
||||
self.linear_init_patch.stop()
|
||||
self.mlp_world_size_patch.stop()
|
||||
self.tensor_world_size_patch.stop()
|
||||
self.mlp_rank_patch.stop()
|
||||
self.tensor_rank_patch.stop()
|
||||
self.get_mlp_tp_group_mock.stop()
|
||||
self.tensor_model_parallel_all_gather_mock.stop()
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -206,7 +206,15 @@ class TestAscendLogitsProcessor(unittest.TestCase):
|
||||
return_value=True),
|
||||
patch(
|
||||
"vllm_ascend.ops.vocab_parallel_embedding.get_lmhead_tp_group.all_to_all",
|
||||
return_value=torch.randn(1, self.vocab_size))
|
||||
return_value=torch.randn(1, self.vocab_size)),
|
||||
patch(
|
||||
"vllm_ascend.ops.vocab_parallel_embedding.get_lmhead_tp_group.all_gather",
|
||||
return_value=torch.randn(1, self.vocab_size)),
|
||||
patch(
|
||||
"vllm_ascend.core.schedule_config.AscendSchedulerConfig.initialize_from_config",
|
||||
return_value=MagicMock(max_num_batched_tokens=1000,
|
||||
max_model_len=512,
|
||||
enable_chunked_prefill=False))
|
||||
]
|
||||
|
||||
for p in self.patches:
|
||||
|
||||
@@ -1,167 +0,0 @@
|
||||
from importlib import reload
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import vllm
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
from tests.ut.base import PytestBase
|
||||
from vllm_ascend.patch.worker.patch_common import patch_linear
|
||||
|
||||
|
||||
class TestAscendRowParallelLinear(PytestBase):
|
||||
|
||||
def init_row_parallel_linear(self, mocker: MockerFixture):
|
||||
mocker.patch(
|
||||
"vllm_ascend.patch.worker.patch_common.patch_linear.AscendRowParallelLinear.__init__",
|
||||
return_value=None,
|
||||
)
|
||||
mocker.patch("torch.nn.Module.__setattr__")
|
||||
mocker.patch("torch.nn.Module.__getattr__")
|
||||
mocker.patch("torch.nn.Module.__delattr__")
|
||||
return patch_linear.AscendRowParallelLinear(
|
||||
input_size=128,
|
||||
output_size=256,
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"version, expected",
|
||||
[
|
||||
("1.0.0", 1),
|
||||
("2.1.0", 1),
|
||||
],
|
||||
)
|
||||
def test_get_hcomm_info(self, version, expected, mocker: MockerFixture):
|
||||
mock_group = mocker.MagicMock()
|
||||
backend = mocker.MagicMock()
|
||||
backend.get_hccl_comm_name = lambda x: x
|
||||
mock_group._get_backend = lambda x: backend
|
||||
mock_group.get_hccl_comm_name = lambda x: x
|
||||
mocker.patch("torch.distributed.get_rank", return_value=1)
|
||||
mocker.patch(
|
||||
"torch.distributed.get_global_rank",
|
||||
return_value=0,
|
||||
)
|
||||
mocker.patch("torch.__version__", new=version)
|
||||
hcomm_info = patch_linear.AscendRowParallelLinear.get_hcomm_info(
|
||||
mock_group)
|
||||
assert hcomm_info == expected
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"skip_bias_add, return_bias, bias, expected",
|
||||
[
|
||||
(True, False, torch.tensor(1.0), torch.tensor(14.0)),
|
||||
(False, True, torch.tensor(1.0), (torch.tensor(14.0), None)),
|
||||
(
|
||||
True,
|
||||
True,
|
||||
torch.tensor(1.0),
|
||||
(torch.tensor(14.0), torch.tensor(1.0)),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_forward(
|
||||
self,
|
||||
skip_bias_add,
|
||||
return_bias,
|
||||
bias,
|
||||
expected,
|
||||
mocker: MockerFixture,
|
||||
):
|
||||
mocker_tp_group = mocker.MagicMock()
|
||||
mocker_tp_group.device_group = mocker.MagicMock()
|
||||
row_parallel_linear = self.init_row_parallel_linear(mocker)
|
||||
row_parallel_linear.__dict__["tp_rank"] = 0
|
||||
row_parallel_linear.__dict__["skip_bias_add"] = skip_bias_add
|
||||
row_parallel_linear.__dict__["return_bias"] = return_bias
|
||||
row_parallel_linear.__dict__["bias"] = bias
|
||||
row_parallel_linear.__dict__["qyuant_method"] = mocker.MagicMock()
|
||||
row_parallel_linear.__dict__["calc_input"] = lambda x: x # noqa
|
||||
row_parallel_linear.__dict__[
|
||||
"calc_output"] = lambda x: x.matmul( # noqa
|
||||
torch.tensor([1.0, 2.0]))
|
||||
ret = row_parallel_linear.forward(torch.tensor([10.0, 2.0]))
|
||||
if isinstance(ret, tuple):
|
||||
assert torch.allclose(ret[0], expected[0])
|
||||
if ret[1] is None:
|
||||
assert ret[1] == expected[1]
|
||||
else:
|
||||
assert torch.allclose(ret[1], expected[1])
|
||||
else:
|
||||
assert torch.allclose(ret, expected)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_is_parallel, expected",
|
||||
[
|
||||
(True, torch.tensor([10.0, 2.0])),
|
||||
(False, torch.tensor([10.0])),
|
||||
],
|
||||
)
|
||||
def test_calc_input(
|
||||
self,
|
||||
input_is_parallel,
|
||||
expected,
|
||||
mocker: MockerFixture,
|
||||
):
|
||||
row_parallel_linear = self.init_row_parallel_linear(mocker)
|
||||
row_parallel_linear.__dict__["input_is_parallel"] = input_is_parallel
|
||||
input_tensor = torch.Tensor([10, 2])
|
||||
mocker.patch(
|
||||
"vllm_ascend.patch.worker.patch_common.patch_linear.get_tensor_model_parallel_rank", # noqa
|
||||
return_value=0,
|
||||
)
|
||||
mocker.patch(
|
||||
"vllm_ascend.patch.worker.patch_common.patch_linear.split_tensor_along_last_dim", # noqa
|
||||
return_value=[torch.Tensor([10]),
|
||||
torch.Tensor([2])],
|
||||
)
|
||||
input_parallel = row_parallel_linear.calc_input(input_tensor)
|
||||
assert torch.allclose(input_parallel, expected)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"reduce_results, tp_size, expected",
|
||||
[
|
||||
(True, 2, torch.tensor(56.0)),
|
||||
(True, 1, torch.tensor(14.0)),
|
||||
(False, 2, torch.tensor(14.0)),
|
||||
],
|
||||
)
|
||||
def test_calc_output(
|
||||
self,
|
||||
reduce_results,
|
||||
tp_size,
|
||||
expected,
|
||||
mocker: MockerFixture,
|
||||
):
|
||||
quant_method = mocker.MagicMock()
|
||||
quant_method.apply = lambda self, x, bias=None: x.matmul( # noqa
|
||||
torch.tensor([1.0, 2.0]))
|
||||
row_parallel_linear = self.init_row_parallel_linear(mocker)
|
||||
row_parallel_linear.__dict__["reduce_results"] = reduce_results
|
||||
row_parallel_linear.__dict__["tp_size"] = tp_size
|
||||
row_parallel_linear.__dict__["quant_method"] = quant_method
|
||||
row_parallel_linear.__dict__["tp_rank"] = 0
|
||||
row_parallel_linear.__dict__["get_hcomm_info"] = lambda x: None # noqa
|
||||
|
||||
mocker.patch(
|
||||
"vllm_ascend.patch.worker.patch_common.patch_linear.get_tp_group",
|
||||
return_value=mocker.MagicMock(device_group=mocker.MagicMock()),
|
||||
)
|
||||
mocker.patch(
|
||||
"torch_npu.npu_mm_all_reduce_base",
|
||||
side_effect=lambda input_, weight, hccl_info, bias: input_.
|
||||
matmul( # noqa
|
||||
torch.tensor([4.0, 8.0])),
|
||||
) # noqa
|
||||
ret = row_parallel_linear.calc_output(torch.tensor([10.0, 2.0]))
|
||||
assert torch.allclose(ret, expected)
|
||||
|
||||
def test_enable_allreduce_matmul(self, mocker: MockerFixture):
|
||||
mocker.patch.object(envs_ascend,
|
||||
"VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE",
|
||||
new=True)
|
||||
reload(patch_linear)
|
||||
assert envs_ascend.VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE
|
||||
assert id(vllm.model_executor.layers.linear.RowParallelLinear) == id(
|
||||
patch_linear.AscendRowParallelLinear)
|
||||
@@ -11,8 +11,19 @@ from vllm_ascend.quantization.w4a8_dynamic import (
|
||||
class TestAscendW4A8DynamicLinearMethod(TestBase):
|
||||
|
||||
def setUp(self):
|
||||
self.method = AscendW4A8DynamicLinearMethod()
|
||||
self.method.group_size = 8
|
||||
with patch(
|
||||
'vllm_ascend.quantization.w4a8_dynamic.get_current_vllm_config'
|
||||
) as mock_get_current_vllm_config:
|
||||
mock_vllm_config = Mock()
|
||||
mock_vllm_config.quant_config = Mock(
|
||||
quant_description={"group_size": 256})
|
||||
mock_vllm_config.scheduler_config = Mock(
|
||||
max_num_batched_tokens=2048,
|
||||
max_model_len=2048,
|
||||
enable_chunked_prefill=False)
|
||||
mock_get_current_vllm_config.return_value = mock_vllm_config
|
||||
self.method = AscendW4A8DynamicLinearMethod()
|
||||
self.method.group_size = 8
|
||||
|
||||
def test_get_weight(self):
|
||||
weight = self.method.get_weight(8, 32, torch.bfloat16)
|
||||
|
||||
@@ -18,17 +18,37 @@ class TestAscendW8A8FusedMoEMethod(TestBase):
|
||||
@patch("vllm_ascend.quantization.w8a8_dynamic.get_ep_group")
|
||||
def setUp(self, mock_get_ep_group, mock_get_ascend_config,
|
||||
mock_get_mc2_group, mock_get_rank):
|
||||
mock_ep_group = Mock()
|
||||
mock_get_ep_group.return_value = mock_ep_group
|
||||
mock_ascend_config = Mock()
|
||||
mock_ascend_config.torchair_graph_config = Mock(enabled=False)
|
||||
mock_get_ascend_config.return_value = mock_ascend_config
|
||||
mock_mc2_group = Mock(device_group=0)
|
||||
mock_get_mc2_group.return_value = mock_mc2_group
|
||||
mock_rank = Mock()
|
||||
mock_get_rank.return_value = mock_rank
|
||||
with patch(
|
||||
'vllm_ascend.quantization.w8a8_dynamic.get_current_vllm_config'
|
||||
) as mock_get_current_vllm_config:
|
||||
mock_vllm_config = Mock()
|
||||
mock_vllm_config.quant_config = Mock(
|
||||
quant_description={"group_size": 256})
|
||||
mock_vllm_config.scheduler_config = Mock(
|
||||
max_num_batched_tokens=2048,
|
||||
max_model_len=2048,
|
||||
enable_chunked_prefill=False)
|
||||
mock_get_current_vllm_config.return_value = mock_vllm_config
|
||||
mock_ep_group = Mock()
|
||||
mock_get_ep_group.return_value = mock_ep_group
|
||||
mock_ascend_config = Mock()
|
||||
|
||||
self.quant_method = AscendW8A8DynamicFusedMoEMethod()
|
||||
# 创建一个具有具体属性的 Mock 对象来表示 ascend_scheduler_config
|
||||
mock_ascend_scheduler_config = Mock()
|
||||
mock_ascend_scheduler_config.enabled = False
|
||||
mock_ascend_scheduler_config.max_num_batched_tokens = 1024
|
||||
mock_ascend_scheduler_config.max_model_len = 2048
|
||||
mock_ascend_config.ascend_scheduler_config = mock_ascend_scheduler_config
|
||||
|
||||
mock_ascend_config.torchair_graph_config = Mock(enabled=False)
|
||||
mock_ascend_config.enable_chunked_prefill = False
|
||||
mock_get_ascend_config.return_value = mock_ascend_config
|
||||
mock_mc2_group = Mock(device_group=0)
|
||||
mock_get_mc2_group.return_value = mock_mc2_group
|
||||
mock_rank = Mock()
|
||||
mock_get_rank.return_value = mock_rank
|
||||
|
||||
self.quant_method = AscendW8A8DynamicFusedMoEMethod()
|
||||
|
||||
def test_get_weight(self):
|
||||
param_dict = self.quant_method.get_weight(self.num_experts,
|
||||
|
||||
@@ -359,3 +359,27 @@ class TestAscendConfig(TestBase):
|
||||
test_vllm_config.parallel_config = ParallelConfig(
|
||||
data_parallel_size=4, tensor_parallel_size=2)
|
||||
init_ascend_config(test_vllm_config)
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
test_vllm_config.additional_config = {
|
||||
"torchair_graph_config": {
|
||||
"enabled": True,
|
||||
},
|
||||
"oproj_tensor_parallel_size": 2,
|
||||
"refresh": True
|
||||
}
|
||||
test_vllm_config.parallel_config = ParallelConfig(
|
||||
data_parallel_size=4, tensor_parallel_size=2)
|
||||
init_ascend_config(test_vllm_config)
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
test_vllm_config.additional_config = {
|
||||
"torchair_graph_config": {
|
||||
"enabled": False,
|
||||
},
|
||||
"oproj_tensor_parallel_size": 2,
|
||||
"refresh": True
|
||||
}
|
||||
test_vllm_config.parallel_config = ParallelConfig(
|
||||
data_parallel_size=4, tensor_parallel_size=1)
|
||||
init_ascend_config(test_vllm_config)
|
||||
|
||||
@@ -100,6 +100,11 @@ def mock_distributed():
|
||||
pp_group.rank_in_group = 0
|
||||
pp_group.world_size = 1
|
||||
|
||||
mlp_tp_group = Mock(spec=GroupCoordinator)
|
||||
mlp_tp_group.rank_in_group = 0
|
||||
mlp_tp_group.world_size = 1
|
||||
mlp_tp_group.all_gather = Mock(return_value=torch.randn(2, 4, 128))
|
||||
|
||||
mock_vllm_config = Mock()
|
||||
mock_vllm_config.scheduler_config = Mock(max_num_seqs=256)
|
||||
mock_vllm_config.model_config = Mock(max_model_len=2048, quant_config=None)
|
||||
@@ -196,10 +201,6 @@ def test_torchair_deepseek_v2_mlp(mock_distributed, base_config):
|
||||
quant_config=None)
|
||||
assert isinstance(mlp.act_fn, TorchairDeepseekV2SiluAndMul)
|
||||
|
||||
x = torch.randn(2, 4, 128)
|
||||
output = mlp(x)
|
||||
assert output.shape == (2, 4, 128)
|
||||
|
||||
with patch(
|
||||
"vllm_ascend.torchair.models.torchair_deepseek_v2.QuantizationConfig"
|
||||
) as mock_quant_config:
|
||||
@@ -322,4 +323,4 @@ def test_torchair_deepseek_v2_for_causal_lm(mock_distributed, vllm_config):
|
||||
"vllm.model_executor.model_loader.weight_utils.default_weight_loader"
|
||||
):
|
||||
loaded = model.load_weights(weights)
|
||||
assert loaded is not None
|
||||
assert loaded is not None
|
||||
Reference in New Issue
Block a user