From 5a7181569c58630fed55e1f78eec615dc2034b74 Mon Sep 17 00:00:00 2001
From: lidenghui1110 <30521952+lidenghui1110@users.noreply.github.com>
Date: Sun, 7 Sep 2025 10:31:32 +0800
Subject: [PATCH] [feat]: oproj tensor parallelism in pure DP and graph-mode
scenarios. (#2167)
### What this PR does / why we need it?
This PR introduces tensor model parallelism for the o_proj matrix to
reduce memory consumption. It only supports graph mode in the pure DP
scenario.
In a DeepSeek R1 W8A8 PD-disaggregated Decode instance using pure DP with
oproj_tensor_parallel_size = 8, we observe a 1 ms TPOT increase and save
5.8 GB of NPU memory per rank. We achieved the best performance with
oproj_tensor_parallel_size = 4, with no TPOT increase.
performance data:
### Does this PR introduce _any_ user-facing change?
This PR introduces one new config in `additional_config`.
| Name | Effect | Required | Type | Constraints |
| :---------------------------- |
:--------------------------------------- | :------- | :--- |
:----------------- |
| oproj_tensor_parallel_size | Split the o_proj matrix along the row
dimension (head num * head dim) into oproj_tensor_parallel_size pieces.
| No | int | default value is None, once this value is set, the feature
will be enabled, head num * head dim must be divisible by this value. |
Example:
`--additional_config={"oproj_tensor_parallel_size": 8}`
### How was this patch tested?
- vLLM version: v0.10.1.1
- vLLM main:
https://github.com/vllm-project/vllm/commit/eddaafc1c77b0690194cbd1b73747d572793838c
---------
Signed-off-by: zzhx1
Co-authored-by: zzh
---
.../configuration/additional_config.md | 1 +
tests/ut/distributed/test_parallel_state.py | 10 +-
tests/ut/models/test_deepseek_v2.py | 1 -
tests/ut/models/test_qwen2_5_vl.py | 12 +
tests/ut/ops/test_linear.py | 394 +++---------------
tests/ut/ops/test_vocab_parallel_embedding.py | 10 +-
.../worker/patch_common/test_patch_linear.py | 167 --------
tests/ut/quantization/test_w4a8_dynamic.py | 15 +-
tests/ut/quantization/test_w8a8_dynamic.py | 40 +-
tests/ut/test_ascend_config.py | 24 ++
.../models/test_torchair_deepseek_v2.py | 11 +-
vllm_ascend/ascend_config.py | 18 +
vllm_ascend/distributed/parallel_state.py | 27 +-
vllm_ascend/ops/linear.py | 360 +++++++++++-----
.../patch/worker/patch_common/__init__.py | 3 +-
.../patch/worker/patch_common/patch_linear.py | 147 -------
.../patch/worker/patch_common/patch_lora.py | 15 +
.../patch_common/patch_lora_embedding.py | 7 -
.../worker/patch_common/patch_lora_linear.py | 52 +++
vllm_ascend/platform.py | 1 -
vllm_ascend/quantization/quant_config.py | 17 +-
.../torchair/models/torchair_deepseek_v2.py | 19 +-
vllm_ascend/utils.py | 32 +-
23 files changed, 576 insertions(+), 807 deletions(-)
delete mode 100644 tests/ut/patch/worker/patch_common/test_patch_linear.py
delete mode 100644 vllm_ascend/patch/worker/patch_common/patch_linear.py
create mode 100644 vllm_ascend/patch/worker/patch_common/patch_lora.py
create mode 100644 vllm_ascend/patch/worker/patch_common/patch_lora_linear.py
diff --git a/docs/source/user_guide/configuration/additional_config.md b/docs/source/user_guide/configuration/additional_config.md
index c67f340..81c5ef6 100644
--- a/docs/source/user_guide/configuration/additional_config.md
+++ b/docs/source/user_guide/configuration/additional_config.md
@@ -35,6 +35,7 @@ The following table lists the additional configuration options available in vLLM
| `kv_cache_dtype` | str | `None` | When using the kv cache quantization method, kv cache dtype needs to be set, currently only int8 is supported. |
| `enable_shared_expert_dp` | bool | `False` | When the shared expert in DP, it has better performance but consumes more memory. Currently only DeepSeek series models are supported to use. |
| `lmhead_tensor_parallel_size` | int | `None` | The custom tensor parallel size of lmhead. |
+| `oproj_tensor_parallel_size` | int | `None` | The custom tensor parallel size of oproj. |
The details of each config option are as follows:
diff --git a/tests/ut/distributed/test_parallel_state.py b/tests/ut/distributed/test_parallel_state.py
index afc22c8..6b52b7b 100644
--- a/tests/ut/distributed/test_parallel_state.py
+++ b/tests/ut/distributed/test_parallel_state.py
@@ -4,8 +4,8 @@ import pytest
from vllm.config import ParallelConfig
from vllm_ascend.distributed.parallel_state import (
- _LMTP, _MC2, destroy_ascend_model_parallel, get_lmhead_tp_group,
- get_mc2_group, init_ascend_model_parallel)
+ _LMTP, _MC2, _OTP, destroy_ascend_model_parallel, get_lmhead_tp_group,
+ get_mc2_group, get_otp_group, init_ascend_model_parallel)
@pytest.fixture
@@ -29,16 +29,20 @@ def mock_distributed():
def test_init_ascend_model_parallel(mock_distributed, parallel_config):
mock_ascend_config = MagicMock()
mock_ascend_config.lmhead_tensor_parallel_size = 2
+ mock_ascend_config.oproj_tensor_parallel_size = 2
with patch('vllm_ascend.distributed.parallel_state.model_parallel_initialized', return_value=False), \
patch('vllm_ascend.distributed.parallel_state.init_model_parallel_group'), \
patch('vllm_ascend.distributed.parallel_state.get_ascend_config', return_value=mock_ascend_config):
init_ascend_model_parallel(parallel_config)
mc2_group = get_mc2_group()
- assert mc2_group is not None
lmheadtp_group = get_lmhead_tp_group()
+ otp_group = get_otp_group()
+ assert mc2_group is not None
+ assert otp_group is not None
assert lmheadtp_group is not None
destroy_ascend_model_parallel()
assert _MC2 is None
assert _LMTP is None
+ assert _OTP is None
diff --git a/tests/ut/models/test_deepseek_v2.py b/tests/ut/models/test_deepseek_v2.py
index df14a2a..c6eeb97 100644
--- a/tests/ut/models/test_deepseek_v2.py
+++ b/tests/ut/models/test_deepseek_v2.py
@@ -174,7 +174,6 @@ def test_row_parallel_linear(cls, mock_distributed):
linear = cls(input_size=128, output_size=64, bias=False, quant_config=None)
linear.quant_method = Mock()
linear.quant_method.apply.return_value = torch.randn(2, 4, 64)
-
input_ = torch.randn(2, 4, 128)
with patch("vllm_ascend.models.deepseek_v2.split_tensor_along_last_dim",
return_value=[torch.randn(2, 4, 64)]):
diff --git a/tests/ut/models/test_qwen2_5_vl.py b/tests/ut/models/test_qwen2_5_vl.py
index 15367eb..d33f337 100644
--- a/tests/ut/models/test_qwen2_5_vl.py
+++ b/tests/ut/models/test_qwen2_5_vl.py
@@ -286,6 +286,18 @@ class TestAscendQwen2_5_VisionTransformer(PytestBase):
"vllm_ascend.models.qwen2_5_vl.parallel_state.get_tensor_model_parallel_world_size",
return_value=2,
)
+ mocker.patch(
+ "vllm_ascend.ops.linear.divide",
+ return_value=2,
+ )
+
+ mock_group = mocker.MagicMock()
+ mock_group.rank_in_group = 0
+ mock_group.world_size = 2
+ mocker.patch(
+ "vllm_ascend.ops.linear.get_tp_group",
+ return_value=mock_group,
+ )
vision_transformer = AscendQwen2_5_VisionTransformer(
vision_config,
diff --git a/tests/ut/ops/test_linear.py b/tests/ut/ops/test_linear.py
index 28b26b7..a0d7f06 100644
--- a/tests/ut/ops/test_linear.py
+++ b/tests/ut/ops/test_linear.py
@@ -1,363 +1,105 @@
import os
import unittest
from unittest import mock
+from unittest.mock import MagicMock, patch
import torch
-from vllm_ascend.ops.linear import (AscendMlpColumnParallelLinear,
- AscendMlpMergedColumnParallelLinear,
- AscendMlpRowParallelLinear, LinearBase,
- QuantizationConfig)
+from vllm_ascend import ascend_config
+from vllm_ascend.distributed import parallel_state
+from vllm_ascend.ops.linear import (AscendColumnParallelLinear,
+ AscendMergedColumnParallelLinear,
+ AscendRowParallelLinear)
-class TestAscendMlpRowParallelLinear(unittest.TestCase):
+class BaseLinearTest(unittest.TestCase):
def setUp(self):
- os.environ["VLLM_ASCEND_ENABLE_MLP_OPTIMIZE"] = "1"
- self.tensor_parallel_world_size = 2
- self.tensor_parallel_rank = 0
- self.mlp_tensor_parallel_world_size = 2
- self.mlp_tensor_parallel_rank = 1
+ self.mock_group = mock.MagicMock()
+ self.mock_group.world_size = 2
+ self.mock_group.rank_in_group = 0
- self.get_tensor_model_parallel_world_size_patch = mock.patch(
- 'vllm_ascend.ops.linear.get_tensor_model_parallel_world_size',
- return_value=self.tensor_parallel_world_size)
- self.get_tensor_model_parallel_rank_patch = mock.patch(
- 'vllm_ascend.ops.linear.get_tensor_model_parallel_rank',
- return_value=self.tensor_parallel_rank)
- self.get_mlp_tensor_model_parallel_world_size_patch = mock.patch(
- 'vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_world_size',
- return_value=self.mlp_tensor_parallel_world_size)
- self.get_mlp_tensor_model_parallel_rank_patch = mock.patch(
- 'vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_rank',
- return_value=self.mlp_tensor_parallel_rank)
+ parallel_state._MLP_TP = self.mock_group
+ parallel_state._OTP = self.mock_group
- self.get_tensor_model_parallel_world_size_mock = \
- self.get_tensor_model_parallel_world_size_patch.start()
- self.get_tensor_model_parallel_rank_mock = \
- self.get_tensor_model_parallel_rank_patch.start()
- self.get_mlp_tensor_model_parallel_world_size_mock = \
- self.get_mlp_tensor_model_parallel_world_size_patch.start()
- self.get_mlp_tensor_model_parallel_rank_mock = \
- self.get_mlp_tensor_model_parallel_rank_patch.start()
+ self.mock_ascend_config = MagicMock()
+ self.mock_ascend_config.oproj_tensor_parallel_size = 2
- self.split_tensor_along_last_dim_patch = mock.patch(
- 'vllm_ascend.ops.linear.split_tensor_along_last_dim',
- return_value=(torch.randn(10, 8), torch.randn(10, 8)))
- self.tensor_model_parallel_all_reduce_patch = mock.patch(
- 'vllm_ascend.ops.linear.tensor_model_parallel_all_reduce',
- return_value=torch.randn(10, 8))
- self.tensor_model_parallel_all_reduce_mock = \
- self.tensor_model_parallel_all_reduce_patch.start()
- self.split_tensor_along_last_dim_mock = \
- self.split_tensor_along_last_dim_patch.start()
- self.get_mlp_tp_group_patch = \
- mock.patch('vllm_ascend.ops.linear.get_mlp_tp_group')
- self.get_mlp_tp_group_mock = self.get_mlp_tp_group_patch.start()
- self.get_mlp_tp_group_mock.return_value = mock.MagicMock()
- self.get_mlp_tp_group_mock.return_value.reduce_scatter = \
- mock.MagicMock()
+ self.patches = [
+ patch("vllm_ascend.ascend_config.get_ascend_config",
+ return_value=self.mock_ascend_config),
+ patch("vllm_ascend.distributed.parallel_state.get_otp_group",
+ return_value=self.mock_group),
+ patch("vllm_ascend.distributed.parallel_state.get_mlp_tp_group",
+ return_value=self.mock_group),
+ patch("vllm_ascend.ops.linear.get_tp_group",
+ return_value=self.mock_group),
+ patch("vllm_ascend.utils.mlp_tp_enable", return_value=True),
+ patch("vllm_ascend.utils.oproj_tp_enable", return_value=True)
+ ]
+
+ for p in self.patches:
+ p.start()
def tearDown(self):
- self.get_tensor_model_parallel_world_size_patch.stop()
- self.get_tensor_model_parallel_rank_patch.stop()
- self.get_mlp_tensor_model_parallel_world_size_patch.stop()
- self.get_mlp_tensor_model_parallel_rank_patch.stop()
- self.split_tensor_along_last_dim_patch.stop()
- self.tensor_model_parallel_all_reduce_patch.stop()
- self.get_mlp_tp_group_patch.stop()
+ for p in self.patches:
+ p.stop()
- def test_init_with_down_proj_prefix(self):
- layer = AscendMlpRowParallelLinear(input_size=16,
- output_size=8,
- prefix="down_proj")
- self.assertEqual(layer.tp_size, self.mlp_tensor_parallel_world_size)
- self.assertEqual(layer.tp_rank, self.mlp_tensor_parallel_rank)
- self.assertTrue(layer.enable_mlp_optimze)
- def test_forward_with_mlp_optimize(self):
- layer = AscendMlpRowParallelLinear(
+class TestAscendRowParallelLinear(BaseLinearTest):
+
+ def test_mlp_optimize(self):
+ os.environ["VLLM_ASCEND_ENABLE_MLP_OPTIMIZE"] = "1"
+
+ linear = AscendRowParallelLinear(
input_size=16,
output_size=8,
prefix="down_proj",
- input_is_parallel=False,
)
- input_tensor = torch.randn(16, 8) # (batch_size, input_size)
- layer(input_tensor)
+ self.assertEqual(linear.comm_group, parallel_state._MLP_TP)
+ self.assertEqual(linear.forward_type, "mlp_tp")
- self.split_tensor_along_last_dim_mock.assert_called_once_with(
- input_tensor, num_partitions=layer.tp_size)
+ input_tensor = torch.randn(16, 8)
+ linear(input_tensor)
- def test_forward_without_mlp_optimize(self):
- layer = AscendMlpRowParallelLinear(
+ def test_oproj_tp(self):
+ ascend_config._ASCEND_CONFIG = MagicMock()
+ ascend_config._ASCEND_CONFIG.oproj_tensor_parallel_size = 2
+
+ linear = AscendRowParallelLinear(
input_size=16,
output_size=8,
- prefix="other",
- input_is_parallel=False,
+ prefix="o_proj",
)
+ self.assertEqual(linear.comm_group, parallel_state._OTP)
+ self.assertEqual(linear.forward_type, "oproj_tp")
+
input_tensor = torch.randn(16, 8)
- layer(input_tensor)
+ linear(input_tensor)
- self.split_tensor_along_last_dim_mock.assert_called_once_with(
- input_tensor, num_partitions=layer.tp_size)
- self.tensor_model_parallel_all_reduce_mock.assert_called_once()
- def test_skip_bias_add(self):
- layer = AscendMlpRowParallelLinear(
+class TestAscendColumnParallelLinear(BaseLinearTest):
+
+ def test_mlp_tp_init(self):
+ linear = AscendColumnParallelLinear(
input_size=16,
output_size=8,
- skip_bias_add=True,
+ prefix="down_proj",
)
- input_tensor = torch.randn(16, 8)
- output, bias = layer(input_tensor)
-
- self.assertIsNotNone(bias)
-
- def test_no_reduce_results(self):
- layer = AscendMlpRowParallelLinear(input_size=16,
- output_size=8,
- reduce_results=False,
- bias=False)
- input_tensor = torch.randn(16, 8)
- layer(input_tensor)
-
- self.tensor_model_parallel_all_reduce_mock.assert_not_called()
-
- def test_input_not_parallel(self):
- layer = AscendMlpRowParallelLinear(input_size=16,
- output_size=8,
- input_is_parallel=False)
- input_tensor = torch.randn(16, 8)
- layer(input_tensor)
-
- self.split_tensor_along_last_dim_mock.assert_called_once()
-
- def test_exception_when_reduce_false_and_bias(self):
- with self.assertRaises(ValueError):
- AscendMlpRowParallelLinear(input_size=16,
- output_size=8,
- reduce_results=False,
- bias=True,
- skip_bias_add=False)
+ self.assertEqual(linear.comm_group, parallel_state._MLP_TP)
-class TestAscendMlpColumnParallelLinear(unittest.TestCase):
+class TestAscendMergedColumnParallelLinear(BaseLinearTest):
- def setUp(self):
- os.environ["VLLM_ASCEND_ENABLE_MLP_OPTIMIZE"] = "1"
- # Mock distributed functions
- self.mlp_tp_size_patch = \
- mock.patch('vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_world_size')
- self.mlp_tp_size_mock = self.mlp_tp_size_patch.start()
- self.mlp_tp_size_mock.return_value = 2 # Simulate 2 GPUs in MLP TP group
-
- self.mlp_tp_rank_patch = \
- mock.patch('vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_rank')
- self.mlp_tp_rank_mock = self.mlp_tp_rank_patch.start()
- self.mlp_tp_rank_mock.return_value = 0 # Current GPU rank
-
- self.tp_size_patch = \
- mock.patch('vllm_ascend.ops.linear.get_tensor_model_parallel_world_size')
- self.tp_size_mock = self.tp_size_patch.start()
- self.tp_size_mock.return_value = 4 # Simulate 4 GPUs in regular TP group
-
- self.tp_rank_patch = \
- mock.patch('vllm_ascend.ops.linear.get_tensor_model_parallel_rank')
- self.tp_rank_mock = self.tp_rank_patch.start()
- self.tp_rank_mock.return_value = 1 # Current GPU rank
-
- # Mock divide function (assumed to be in your module)
- self.divide_patch = mock.patch('vllm_ascend.ops.linear.divide')
- self.divide_mock = self.divide_patch.start()
- self.divide_mock.side_effect = lambda x, y: x // y # Simulate division
-
- # Mock QuantizationConfig and QuantMethod
- self.quant_config_mock = mock.MagicMock(spec=QuantizationConfig)
-
- # Mock LinearBase initialization
- self.linear_base_init_patch = mock.patch.object(
- LinearBase, "__init__", side_effect=self.mock_linear_base_init)
- self.linear_base_init_patch.start()
-
- self.quant_method_mock = mock.MagicMock()
-
- def mock_linear_base_init(self, instance, *args, **kwargs):
- instance.quant_method = self.quant_method_mock
- instance.params_dtype = mock.MagicMock()
-
- instance.input_size = 16
- instance.output_size = 8
- instance.output_size_per_partition = 4
- instance.params_dtype = torch.float32
-
- def tearDown(self):
- self.mlp_tp_size_patch.stop()
- self.mlp_tp_rank_patch.stop()
- self.tp_size_patch.stop()
- self.tp_rank_patch.stop()
- self.divide_patch.stop()
- self.linear_base_init_patch.stop()
-
- def test_mlp_optimize_initialization(self):
- # Test when prefix contains "gate_up_proj"
- with mock.patch.object(torch.nn.Module, 'register_parameter'):
- layer = AscendMlpColumnParallelLinear(
- input_size=16,
- output_size=8,
- prefix="model.layers.0.gate_up_proj",
- bias=False,
- )
-
- # Verify MLP optimization flags
- self.assertTrue(layer.enable_mlp_optimze)
- self.assertEqual(layer.tp_size, 2)
- self.assertEqual(layer.tp_rank, 0)
- self.assertEqual(layer.input_size_per_partition, 16)
- self.assertEqual(layer.output_size_per_partition, 4)
-
- # Check quant_method.create_weights was called
- self.quant_method_mock.create_weights.assert_called_once()
-
- def test_regular_parallel_initialization(self):
- # Test when prefix does NOT contain "gate_up_proj"
- with mock.patch.object(torch.nn.Module, 'register_parameter'):
- layer = AscendMlpColumnParallelLinear(
- input_size=16,
- output_size=8,
- prefix="model.layers.0.q_proj",
- quant_config=self.quant_config_mock,
- bias=False,
- )
-
- # Verify regular TP flags
- self.assertFalse(layer.enable_mlp_optimze)
- self.assertEqual(layer.tp_size, 4)
- self.assertEqual(layer.tp_rank, 1)
- self.assertEqual(layer.input_size_per_partition, 16)
- self.assertEqual(layer.output_size_per_partition, 4)
- # Check quant_method.create_weights was called
- self.quant_method_mock.create_weights.assert_called_once()
-
- def test_output_sizes_handling(self):
- # Test when output_sizes is provided
- with mock.patch.object(torch.nn.Module, 'register_parameter'):
- layer = AscendMlpColumnParallelLinear(
- input_size=16,
- output_size=8,
- output_sizes=[4, 4],
- prefix="model.layers.0.qkv_proj",
- quant_config=self.quant_config_mock,
- bias=False,
- )
-
- # Verify output_partition_sizes
- self.assertEqual(layer.output_partition_sizes, [2])
+ def test_merged_mlp_tp_init(self):
+ linear = AscendMergedColumnParallelLinear(
+ input_size=16,
+ output_sizes=[8, 8],
+ prefix="gate_up_proj",
+ )
+ self.assertEqual(linear.comm_group, parallel_state._MLP_TP)
+ self.assertEqual(linear.forward_type, "mlp_tp")
-class TestAscendMlpMergedColumnParallelLinear(unittest.TestCase):
-
- def setUp(self):
- os.environ["VLLM_ASCEND_ENABLE_MLP_OPTIMIZE"] = "1"
- # Mock get_mlp_tensor_model_parallel_world_size and get_tensor_model_parallel_world_size
- self.mlp_world_size_patch = \
- mock.patch("vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_world_size", return_value=2)
- self.tensor_world_size_patch = \
- mock.patch("vllm_ascend.ops.linear.get_tensor_model_parallel_world_size", return_value=2)
- self.mlp_world_size_patch.start()
- self.tensor_world_size_patch.start()
-
- # Mock get_mlp_tensor_model_parallel_rank and get_tensor_model_parallel_rank
- self.mlp_rank_patch = \
- mock.patch("vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_rank", return_value=0)
- self.tensor_rank_patch = \
- mock.patch("vllm_ascend.ops.linear.get_tensor_model_parallel_rank", return_value=0)
- self.mlp_rank_patch.start()
- self.tensor_rank_patch.start()
-
- # Mock all_gather methods
- self.get_mlp_tp_group_patch = \
- mock.patch('vllm_ascend.ops.linear.get_mlp_tp_group')
- self.get_mlp_tp_group_mock = self.get_mlp_tp_group_patch.start()
- self.get_mlp_tp_group_mock.return_value = mock.MagicMock()
- self.get_mlp_tp_group_mock.return_value.all_gather = mock.MagicMock()
- self.tensor_model_parallel_all_gather_patch = mock.patch(
- 'vllm_ascend.ops.linear.tensor_model_parallel_all_gather',
- return_value=torch.randn(10, 8))
- self.tensor_model_parallel_all_gather_mock = \
- self.tensor_model_parallel_all_gather_patch.start()
-
- # Mock AscendMlpColumnParallelLinear's __init__
- self.linear_init_patch = mock.patch.object(
- AscendMlpColumnParallelLinear,
- "__init__",
- side_effect=self.mock_linear_init)
- self.linear_init_patch.start()
-
- # Create mock objects
- self.quant_method_mock = mock.MagicMock()
- self.apply_output = torch.randn(2, 8)
-
- self.quant_method_mock.apply.return_value = self.apply_output
-
- def mock_linear_init(self, instance, *args, **kwargs):
- torch.nn.Module.__init__(instance)
- # Set quant_method and other attributes
- instance.quant_method = self.quant_method_mock
- instance.bias = torch.nn.Parameter(torch.randn(8)) # Example bias
- instance.input_size = 16
- instance.output_size = 8
- instance.gather_output = False
- instance.skip_bias_add = False
- instance.return_bias = True
-
- def test_forward_with_enable_mlp_optimze(self):
- # Setup input
- input_tensor = torch.randn(1, 16)
-
- # Create instance with prefix "gate_up_proj" to trigger enable_mlp_optimze = True
- layer = AscendMlpMergedColumnParallelLinear(input_size=16,
- output_sizes=[8],
- bias=True,
- gather_output=False,
- skip_bias_add=False,
- params_dtype=torch.float32,
- quant_config=None,
- prefix="other_proj")
-
- # Call forward
- output, bias = layer(input_tensor)
-
- # Validate calls
- self.assertEqual(output.shape, self.apply_output.shape)
-
- def test_forward_without_enable_mlp_optimze(self):
- # Setup input
- input_tensor = torch.randn(1, 16)
-
- # Create instance with prefix not containing "gate_up_proj"
- layer = AscendMlpMergedColumnParallelLinear(input_size=16,
- output_sizes=[8],
- bias=True,
- gather_output=False,
- skip_bias_add=False,
- params_dtype=torch.float32,
- quant_config=None,
- prefix="other_proj")
-
- # Call forward
- output, bias = layer(input_tensor)
-
- # Validate calls
- self.quant_method_mock.apply.assert_called_once_with(
- layer, input_tensor, layer.bias)
- self.tensor_model_parallel_all_gather_mock.assert_not_called()
- self.assertEqual(output.shape, self.apply_output.shape)
-
- def tearDown(self):
- self.linear_init_patch.stop()
- self.mlp_world_size_patch.stop()
- self.tensor_world_size_patch.stop()
- self.mlp_rank_patch.stop()
- self.tensor_rank_patch.stop()
- self.get_mlp_tp_group_mock.stop()
- self.tensor_model_parallel_all_gather_mock.stop()
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tests/ut/ops/test_vocab_parallel_embedding.py b/tests/ut/ops/test_vocab_parallel_embedding.py
index 5378b19..66163f5 100644
--- a/tests/ut/ops/test_vocab_parallel_embedding.py
+++ b/tests/ut/ops/test_vocab_parallel_embedding.py
@@ -206,7 +206,15 @@ class TestAscendLogitsProcessor(unittest.TestCase):
return_value=True),
patch(
"vllm_ascend.ops.vocab_parallel_embedding.get_lmhead_tp_group.all_to_all",
- return_value=torch.randn(1, self.vocab_size))
+ return_value=torch.randn(1, self.vocab_size)),
+ patch(
+ "vllm_ascend.ops.vocab_parallel_embedding.get_lmhead_tp_group.all_gather",
+ return_value=torch.randn(1, self.vocab_size)),
+ patch(
+ "vllm_ascend.core.schedule_config.AscendSchedulerConfig.initialize_from_config",
+ return_value=MagicMock(max_num_batched_tokens=1000,
+ max_model_len=512,
+ enable_chunked_prefill=False))
]
for p in self.patches:
diff --git a/tests/ut/patch/worker/patch_common/test_patch_linear.py b/tests/ut/patch/worker/patch_common/test_patch_linear.py
deleted file mode 100644
index b7fbbc4..0000000
--- a/tests/ut/patch/worker/patch_common/test_patch_linear.py
+++ /dev/null
@@ -1,167 +0,0 @@
-from importlib import reload
-
-import pytest
-import torch
-import vllm
-from pytest_mock import MockerFixture
-
-import vllm_ascend.envs as envs_ascend
-from tests.ut.base import PytestBase
-from vllm_ascend.patch.worker.patch_common import patch_linear
-
-
-class TestAscendRowParallelLinear(PytestBase):
-
- def init_row_parallel_linear(self, mocker: MockerFixture):
- mocker.patch(
- "vllm_ascend.patch.worker.patch_common.patch_linear.AscendRowParallelLinear.__init__",
- return_value=None,
- )
- mocker.patch("torch.nn.Module.__setattr__")
- mocker.patch("torch.nn.Module.__getattr__")
- mocker.patch("torch.nn.Module.__delattr__")
- return patch_linear.AscendRowParallelLinear(
- input_size=128,
- output_size=256,
- )
-
- @pytest.mark.parametrize(
- "version, expected",
- [
- ("1.0.0", 1),
- ("2.1.0", 1),
- ],
- )
- def test_get_hcomm_info(self, version, expected, mocker: MockerFixture):
- mock_group = mocker.MagicMock()
- backend = mocker.MagicMock()
- backend.get_hccl_comm_name = lambda x: x
- mock_group._get_backend = lambda x: backend
- mock_group.get_hccl_comm_name = lambda x: x
- mocker.patch("torch.distributed.get_rank", return_value=1)
- mocker.patch(
- "torch.distributed.get_global_rank",
- return_value=0,
- )
- mocker.patch("torch.__version__", new=version)
- hcomm_info = patch_linear.AscendRowParallelLinear.get_hcomm_info(
- mock_group)
- assert hcomm_info == expected
-
- @pytest.mark.parametrize(
- "skip_bias_add, return_bias, bias, expected",
- [
- (True, False, torch.tensor(1.0), torch.tensor(14.0)),
- (False, True, torch.tensor(1.0), (torch.tensor(14.0), None)),
- (
- True,
- True,
- torch.tensor(1.0),
- (torch.tensor(14.0), torch.tensor(1.0)),
- ),
- ],
- )
- def test_forward(
- self,
- skip_bias_add,
- return_bias,
- bias,
- expected,
- mocker: MockerFixture,
- ):
- mocker_tp_group = mocker.MagicMock()
- mocker_tp_group.device_group = mocker.MagicMock()
- row_parallel_linear = self.init_row_parallel_linear(mocker)
- row_parallel_linear.__dict__["tp_rank"] = 0
- row_parallel_linear.__dict__["skip_bias_add"] = skip_bias_add
- row_parallel_linear.__dict__["return_bias"] = return_bias
- row_parallel_linear.__dict__["bias"] = bias
- row_parallel_linear.__dict__["qyuant_method"] = mocker.MagicMock()
- row_parallel_linear.__dict__["calc_input"] = lambda x: x # noqa
- row_parallel_linear.__dict__[
- "calc_output"] = lambda x: x.matmul( # noqa
- torch.tensor([1.0, 2.0]))
- ret = row_parallel_linear.forward(torch.tensor([10.0, 2.0]))
- if isinstance(ret, tuple):
- assert torch.allclose(ret[0], expected[0])
- if ret[1] is None:
- assert ret[1] == expected[1]
- else:
- assert torch.allclose(ret[1], expected[1])
- else:
- assert torch.allclose(ret, expected)
-
- @pytest.mark.parametrize(
- "input_is_parallel, expected",
- [
- (True, torch.tensor([10.0, 2.0])),
- (False, torch.tensor([10.0])),
- ],
- )
- def test_calc_input(
- self,
- input_is_parallel,
- expected,
- mocker: MockerFixture,
- ):
- row_parallel_linear = self.init_row_parallel_linear(mocker)
- row_parallel_linear.__dict__["input_is_parallel"] = input_is_parallel
- input_tensor = torch.Tensor([10, 2])
- mocker.patch(
- "vllm_ascend.patch.worker.patch_common.patch_linear.get_tensor_model_parallel_rank", # noqa
- return_value=0,
- )
- mocker.patch(
- "vllm_ascend.patch.worker.patch_common.patch_linear.split_tensor_along_last_dim", # noqa
- return_value=[torch.Tensor([10]),
- torch.Tensor([2])],
- )
- input_parallel = row_parallel_linear.calc_input(input_tensor)
- assert torch.allclose(input_parallel, expected)
-
- @pytest.mark.parametrize(
- "reduce_results, tp_size, expected",
- [
- (True, 2, torch.tensor(56.0)),
- (True, 1, torch.tensor(14.0)),
- (False, 2, torch.tensor(14.0)),
- ],
- )
- def test_calc_output(
- self,
- reduce_results,
- tp_size,
- expected,
- mocker: MockerFixture,
- ):
- quant_method = mocker.MagicMock()
- quant_method.apply = lambda self, x, bias=None: x.matmul( # noqa
- torch.tensor([1.0, 2.0]))
- row_parallel_linear = self.init_row_parallel_linear(mocker)
- row_parallel_linear.__dict__["reduce_results"] = reduce_results
- row_parallel_linear.__dict__["tp_size"] = tp_size
- row_parallel_linear.__dict__["quant_method"] = quant_method
- row_parallel_linear.__dict__["tp_rank"] = 0
- row_parallel_linear.__dict__["get_hcomm_info"] = lambda x: None # noqa
-
- mocker.patch(
- "vllm_ascend.patch.worker.patch_common.patch_linear.get_tp_group",
- return_value=mocker.MagicMock(device_group=mocker.MagicMock()),
- )
- mocker.patch(
- "torch_npu.npu_mm_all_reduce_base",
- side_effect=lambda input_, weight, hccl_info, bias: input_.
- matmul( # noqa
- torch.tensor([4.0, 8.0])),
- ) # noqa
- ret = row_parallel_linear.calc_output(torch.tensor([10.0, 2.0]))
- assert torch.allclose(ret, expected)
-
- def test_enable_allreduce_matmul(self, mocker: MockerFixture):
- mocker.patch.object(envs_ascend,
- "VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE",
- new=True)
- reload(patch_linear)
- assert envs_ascend.VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE
- assert id(vllm.model_executor.layers.linear.RowParallelLinear) == id(
- patch_linear.AscendRowParallelLinear)
diff --git a/tests/ut/quantization/test_w4a8_dynamic.py b/tests/ut/quantization/test_w4a8_dynamic.py
index d7fdf82..70256af 100644
--- a/tests/ut/quantization/test_w4a8_dynamic.py
+++ b/tests/ut/quantization/test_w4a8_dynamic.py
@@ -11,8 +11,19 @@ from vllm_ascend.quantization.w4a8_dynamic import (
class TestAscendW4A8DynamicLinearMethod(TestBase):
def setUp(self):
- self.method = AscendW4A8DynamicLinearMethod()
- self.method.group_size = 8
+ with patch(
+ 'vllm_ascend.quantization.w4a8_dynamic.get_current_vllm_config'
+ ) as mock_get_current_vllm_config:
+ mock_vllm_config = Mock()
+ mock_vllm_config.quant_config = Mock(
+ quant_description={"group_size": 256})
+ mock_vllm_config.scheduler_config = Mock(
+ max_num_batched_tokens=2048,
+ max_model_len=2048,
+ enable_chunked_prefill=False)
+ mock_get_current_vllm_config.return_value = mock_vllm_config
+ self.method = AscendW4A8DynamicLinearMethod()
+ self.method.group_size = 8
def test_get_weight(self):
weight = self.method.get_weight(8, 32, torch.bfloat16)
diff --git a/tests/ut/quantization/test_w8a8_dynamic.py b/tests/ut/quantization/test_w8a8_dynamic.py
index 690778e..f25192c 100644
--- a/tests/ut/quantization/test_w8a8_dynamic.py
+++ b/tests/ut/quantization/test_w8a8_dynamic.py
@@ -18,17 +18,37 @@ class TestAscendW8A8FusedMoEMethod(TestBase):
@patch("vllm_ascend.quantization.w8a8_dynamic.get_ep_group")
def setUp(self, mock_get_ep_group, mock_get_ascend_config,
mock_get_mc2_group, mock_get_rank):
- mock_ep_group = Mock()
- mock_get_ep_group.return_value = mock_ep_group
- mock_ascend_config = Mock()
- mock_ascend_config.torchair_graph_config = Mock(enabled=False)
- mock_get_ascend_config.return_value = mock_ascend_config
- mock_mc2_group = Mock(device_group=0)
- mock_get_mc2_group.return_value = mock_mc2_group
- mock_rank = Mock()
- mock_get_rank.return_value = mock_rank
+ with patch(
+ 'vllm_ascend.quantization.w8a8_dynamic.get_current_vllm_config'
+ ) as mock_get_current_vllm_config:
+ mock_vllm_config = Mock()
+ mock_vllm_config.quant_config = Mock(
+ quant_description={"group_size": 256})
+ mock_vllm_config.scheduler_config = Mock(
+ max_num_batched_tokens=2048,
+ max_model_len=2048,
+ enable_chunked_prefill=False)
+ mock_get_current_vllm_config.return_value = mock_vllm_config
+ mock_ep_group = Mock()
+ mock_get_ep_group.return_value = mock_ep_group
+ mock_ascend_config = Mock()
- self.quant_method = AscendW8A8DynamicFusedMoEMethod()
+ # Create a Mock object with concrete attributes to represent ascend_scheduler_config
+ mock_ascend_scheduler_config = Mock()
+ mock_ascend_scheduler_config.enabled = False
+ mock_ascend_scheduler_config.max_num_batched_tokens = 1024
+ mock_ascend_scheduler_config.max_model_len = 2048
+ mock_ascend_config.ascend_scheduler_config = mock_ascend_scheduler_config
+
+ mock_ascend_config.torchair_graph_config = Mock(enabled=False)
+ mock_ascend_config.enable_chunked_prefill = False
+ mock_get_ascend_config.return_value = mock_ascend_config
+ mock_mc2_group = Mock(device_group=0)
+ mock_get_mc2_group.return_value = mock_mc2_group
+ mock_rank = Mock()
+ mock_get_rank.return_value = mock_rank
+
+ self.quant_method = AscendW8A8DynamicFusedMoEMethod()
def test_get_weight(self):
param_dict = self.quant_method.get_weight(self.num_experts,
diff --git a/tests/ut/test_ascend_config.py b/tests/ut/test_ascend_config.py
index 4c7cfa6..c8013fb 100644
--- a/tests/ut/test_ascend_config.py
+++ b/tests/ut/test_ascend_config.py
@@ -359,3 +359,27 @@ class TestAscendConfig(TestBase):
test_vllm_config.parallel_config = ParallelConfig(
data_parallel_size=4, tensor_parallel_size=2)
init_ascend_config(test_vllm_config)
+
+ with self.assertRaises(AssertionError):
+ test_vllm_config.additional_config = {
+ "torchair_graph_config": {
+ "enabled": True,
+ },
+ "oproj_tensor_parallel_size": 2,
+ "refresh": True
+ }
+ test_vllm_config.parallel_config = ParallelConfig(
+ data_parallel_size=4, tensor_parallel_size=2)
+ init_ascend_config(test_vllm_config)
+
+ with self.assertRaises(AssertionError):
+ test_vllm_config.additional_config = {
+ "torchair_graph_config": {
+ "enabled": False,
+ },
+ "oproj_tensor_parallel_size": 2,
+ "refresh": True
+ }
+ test_vllm_config.parallel_config = ParallelConfig(
+ data_parallel_size=4, tensor_parallel_size=1)
+ init_ascend_config(test_vllm_config)
diff --git a/tests/ut/torchair/models/test_torchair_deepseek_v2.py b/tests/ut/torchair/models/test_torchair_deepseek_v2.py
index e72d023..912ff9a 100644
--- a/tests/ut/torchair/models/test_torchair_deepseek_v2.py
+++ b/tests/ut/torchair/models/test_torchair_deepseek_v2.py
@@ -100,6 +100,11 @@ def mock_distributed():
pp_group.rank_in_group = 0
pp_group.world_size = 1
+ mlp_tp_group = Mock(spec=GroupCoordinator)
+ mlp_tp_group.rank_in_group = 0
+ mlp_tp_group.world_size = 1
+ mlp_tp_group.all_gather = Mock(return_value=torch.randn(2, 4, 128))
+
mock_vllm_config = Mock()
mock_vllm_config.scheduler_config = Mock(max_num_seqs=256)
mock_vllm_config.model_config = Mock(max_model_len=2048, quant_config=None)
@@ -196,10 +201,6 @@ def test_torchair_deepseek_v2_mlp(mock_distributed, base_config):
quant_config=None)
assert isinstance(mlp.act_fn, TorchairDeepseekV2SiluAndMul)
- x = torch.randn(2, 4, 128)
- output = mlp(x)
- assert output.shape == (2, 4, 128)
-
with patch(
"vllm_ascend.torchair.models.torchair_deepseek_v2.QuantizationConfig"
) as mock_quant_config:
@@ -322,4 +323,4 @@ def test_torchair_deepseek_v2_for_causal_lm(mock_distributed, vllm_config):
"vllm.model_executor.model_loader.weight_utils.default_weight_loader"
):
loaded = model.load_weights(weights)
-    assert loaded is not None
+    assert loaded is not None
diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py
index e46cd9a..d053387 100644
--- a/vllm_ascend/ascend_config.py
+++ b/vllm_ascend/ascend_config.py
@@ -61,6 +61,24 @@ class AscendConfig:
raise AssertionError(
"lmhead_tensor_parallel_size is only supported in the pure DP scenario"
)
+ self.oproj_tensor_parallel_size = additional_config.get(
+ "oproj_tensor_parallel_size", None)
+ if self.oproj_tensor_parallel_size is not None:
+ logger.info(
+ f"Enable oproj_tensor_parallel_size={self.oproj_tensor_parallel_size} in pure DP scenario"
+ )
+ if vllm_config.parallel_config.tensor_parallel_size != 1:
+ raise AssertionError(
+ "oproj_tensor_parallel_size is only supported in the pure DP scenario"
+ )
+ if not self.torchair_graph_config.enabled:
+ raise AssertionError(
+ "oproj_tensor_parallel_size is only supported in graph mode"
+ )
+ if vllm_config.kv_transfer_config is None or not vllm_config.kv_transfer_config.is_kv_consumer:
+ raise AssertionError(
+ "oproj_tensor_parallel_size is only supported in pd scenario and can only be used in D node."
+ )
class TorchairGraphConfig:
diff --git a/vllm_ascend/distributed/parallel_state.py b/vllm_ascend/distributed/parallel_state.py
index f81d501..07c707e 100644
--- a/vllm_ascend/distributed/parallel_state.py
+++ b/vllm_ascend/distributed/parallel_state.py
@@ -11,7 +11,7 @@ from vllm_ascend.ascend_config import get_ascend_config
# Currently, mc2 op need their own group coordinator.
_MC2: Optional[GroupCoordinator] = None
_MLP_TP: Optional[GroupCoordinator] = None
-
+_OTP: Optional[GroupCoordinator] = None
_LMTP: Optional[GroupCoordinator] = None
@@ -20,6 +20,12 @@ def get_mc2_group() -> GroupCoordinator:
return _MC2
+def get_otp_group() -> GroupCoordinator:
+ assert _OTP is not None, (
+ "output tensor parallel group is not initialized")
+ return _OTP
+
+
def get_lmhead_tp_group() -> GroupCoordinator:
assert _LMTP is not None, (
"lm head tensor parallel group is not initialized")
@@ -74,6 +80,20 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
backend,
group_name="mlp_tp")
+ # If oproj tensor parallel size is set, we will create a group for it.
+ otp_size = get_ascend_config().oproj_tensor_parallel_size
+ if otp_size is not None:
+ group_ranks = []
+ global _OTP
+ num_oproj_tensor_parallel_groups: int = (world_size // otp_size)
+ for i in range(num_oproj_tensor_parallel_groups):
+ ranks = list(range(i * otp_size, (i + 1) * otp_size))
+ group_ranks.append(ranks)
+ _OTP = init_model_parallel_group(group_ranks,
+ get_world_group().local_rank,
+ backend,
+ group_name="otp")
+
lmhead_tensor_parallel_size = get_ascend_config(
).lmhead_tensor_parallel_size
if lmhead_tensor_parallel_size is not None:
@@ -117,3 +137,8 @@ def destroy_ascend_model_parallel():
if _LMTP:
_LMTP.destroy()
_LMTP = None
+
+ global _OTP
+ if _OTP:
+ _OTP.destroy()
+ _OTP = None
diff --git a/vllm_ascend/ops/linear.py b/vllm_ascend/ops/linear.py
index e2f427e..c29837a 100644
--- a/vllm_ascend/ops/linear.py
+++ b/vllm_ascend/ops/linear.py
@@ -18,27 +18,33 @@ limitations under the License.
from typing import Optional, Union
import torch
+import torch.distributed as dist
+import torch.nn as nn
+import torch_npu
+from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
-from vllm.distributed import (divide, get_tensor_model_parallel_rank,
- get_tensor_model_parallel_world_size,
- split_tensor_along_last_dim,
- tensor_model_parallel_all_gather,
- tensor_model_parallel_all_reduce)
+from vllm.distributed import divide, split_tensor_along_last_dim
+from vllm.distributed.parallel_state import get_tp_group
+from vllm.lora.utils import LinearBase
from vllm.model_executor.layers.linear import (WEIGHT_LOADER_V2_SUPPORTED,
ColumnParallelLinear,
- LinearBase,
MergedColumnParallelLinear,
- RowParallelLinear)
+ QuantizeMethodBase,
+ RowParallelLinear,
+ UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization.base_config import \
QuantizationConfig
from vllm.model_executor.utils import set_weight_attrs
-from vllm_ascend.distributed.parallel_state import (
- get_mlp_tensor_model_parallel_rank,
- get_mlp_tensor_model_parallel_world_size, get_mlp_tp_group)
+from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group,
+ get_otp_group)
+from vllm_ascend.utils import (matmul_allreduce_enable, mlp_tp_enable,
+ oproj_tp_enable)
+
+_HCOMM_INFO = None
-class AscendMlpColumnParallelLinear(ColumnParallelLinear):
+class AscendColumnParallelLinear(ColumnParallelLinear):
"""Linear layer with column parallelism.
Use the MLP tensor parallelism group in the MLP module,
@@ -59,15 +65,15 @@ class AscendMlpColumnParallelLinear(ColumnParallelLinear):
*,
return_bias: bool = True,
):
- # Divide the weight matrix along the last dimension.
- if prefix.find("gate_up_proj") != -1:
- self.tp_size = get_mlp_tensor_model_parallel_world_size()
- self.tp_rank = get_mlp_tensor_model_parallel_rank()
- self.enable_mlp_optimze = True
+ self.comm_group = None
+ if prefix.find("gate_up_proj") != -1 and mlp_tp_enable():
+ self.comm_group = get_mlp_tp_group()
else:
- self.tp_size = get_tensor_model_parallel_world_size()
- self.tp_rank = get_tensor_model_parallel_rank()
- self.enable_mlp_optimze = False
+ self.comm_group = get_tp_group()
+
+ self.tp_size = self.comm_group.world_size
+ self.tp_rank = self.comm_group.rank_in_group
+
self.input_size_per_partition = input_size
self.output_size_per_partition = divide(output_size, self.tp_size)
self.output_partition_sizes = [self.output_size_per_partition]
@@ -77,14 +83,14 @@ class AscendMlpColumnParallelLinear(ColumnParallelLinear):
divide(output_size, self.tp_size)
for output_size in self.output_sizes
]
- LinearBase.__init__(self,
- input_size,
- output_size,
- skip_bias_add,
- params_dtype,
- quant_config,
- prefix,
- return_bias=return_bias)
+ AscendLinearBase.__init__(self,
+ input_size,
+ output_size,
+ skip_bias_add,
+ params_dtype,
+ quant_config,
+ prefix,
+ return_bias=return_bias)
self.gather_output = gather_output
@@ -114,7 +120,7 @@ class AscendMlpColumnParallelLinear(ColumnParallelLinear):
self.register_parameter("bias", None)
-class AscendMlpRowParallelLinear(RowParallelLinear):
+class AscendRowParallelLinear(RowParallelLinear):
"""Linear layer with row parallelism.
Use the MLP tensor parallelism group in the MLP module,
and the original TP group in other modules.
@@ -134,27 +140,37 @@ class AscendMlpRowParallelLinear(RowParallelLinear):
*,
return_bias: bool = True,
):
- if prefix.find("down_proj") != -1:
- self.tp_size = get_mlp_tensor_model_parallel_world_size()
- self.tp_rank = get_mlp_tensor_model_parallel_rank()
- self.enable_mlp_optimze = True
+ if prefix.find("down_proj") != -1 and mlp_tp_enable():
+ comm_group = get_mlp_tp_group()
+ self.forward_type = "mlp_tp"
+ elif prefix.find("o_proj") != -1 and oproj_tp_enable():
+ comm_group = get_otp_group()
+ self.forward_type = "oproj_tp"
+ elif matmul_allreduce_enable():
+ comm_group = get_tp_group()
+ self.forward_type = "matmul_allreduce"
+ self.hcomm_info = self.get_hcomm_info(comm_group.device_group)
else:
- self.tp_size = get_tensor_model_parallel_world_size()
- self.tp_rank = get_tensor_model_parallel_rank()
- self.enable_mlp_optimze = False
+ comm_group = get_tp_group()
+ self.forward_type = "normal"
+ self.comm_group = comm_group
+
+ self.tp_size = self.comm_group.world_size
+ self.tp_rank = self.comm_group.rank_in_group
+
# Divide the weight matrix along the first dimension.
self.input_size_per_partition = divide(input_size, self.tp_size)
self.output_size_per_partition = output_size
self.output_partition_sizes = [output_size]
- LinearBase.__init__(self,
- input_size,
- output_size,
- skip_bias_add,
- params_dtype,
- quant_config,
- prefix,
- return_bias=return_bias)
+ AscendLinearBase.__init__(self,
+ input_size,
+ output_size,
+ skip_bias_add,
+ params_dtype,
+ quant_config,
+ prefix,
+ return_bias=return_bias)
self.input_is_parallel = input_is_parallel
self.reduce_results = reduce_results
@@ -184,61 +200,140 @@ class AscendMlpRowParallelLinear(RowParallelLinear):
else:
self.register_parameter("bias", None)
+ if matmul_allreduce_enable():
+ self.weight_t = self.weight.t()
+
+ @staticmethod
+ def get_hcomm_info(group: ProcessGroup) -> str:
+ """Get the HCCL communication information for the given group."""
+ global _HCOMM_INFO
+ if _HCOMM_INFO is not None:
+ return _HCOMM_INFO
+
+ rank = torch.distributed.get_rank(group)
+ if torch.__version__ > "2.0":
+ global_rank = torch.distributed.get_global_rank(group, rank)
+ _HCOMM_INFO = group._get_backend(
+ torch.device("npu")).get_hccl_comm_name(global_rank)
+ else:
+ _HCOMM_INFO = group.get_hccl_comm_name(rank)
+ return _HCOMM_INFO
+
def forward(
self,
input_,
+ is_prefill: bool = True,
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
- if self.enable_mlp_optimze:
- tp_rank = get_mlp_tensor_model_parallel_rank()
- if self.input_is_parallel:
- input_parallel = input_
- else:
- tp_rank = get_mlp_tensor_model_parallel_rank()
- splitted_input = split_tensor_along_last_dim(
- input_, num_partitions=self.tp_size)
- input_parallel = splitted_input[tp_rank].contiguous()
- # Matrix multiply.
- assert self.quant_method is not None
- # Only fuse bias add into GEMM for rank 0 (this ensures that
- # bias will not get added more than once in TP>1 case)
- bias_ = None if (self.tp_rank > 0
- or self.skip_bias_add) else self.bias
- output_parallel = self.quant_method.apply(self,
- input_parallel,
- bias=bias_)
- output = get_mlp_tp_group().reduce_scatter(output_parallel, 0)
- # output = output[:num_tokens,:]
- # dispose_tensor(output_parallel)
+ # Choose different forward function according to the type of TP group
+ if self.forward_type == "oproj_tp":
+ return self._forward_oproj_tp(input_)
+ elif self.forward_type == "mlp_tp":
+ return self._forward_mlp_tp(input_)
+ elif self.forward_type == "matmul_allreduce":
+ return self._forward_matmul_allreduce(input_)
else:
- if self.input_is_parallel:
- input_parallel = input_
- else:
- tp_rank = get_tensor_model_parallel_rank()
- splitted_input = split_tensor_along_last_dim(
- input_, num_partitions=self.tp_size)
- input_parallel = splitted_input[tp_rank].contiguous()
+ return super().forward(input_)
+
+ # enable custom MLP tensor parallel
+ def _forward_mlp_tp(self, input_: torch.Tensor) -> torch.Tensor:
+
+ if self.input_is_parallel:
+ input_parallel = input_
+ else:
+ splitted_input = split_tensor_along_last_dim(
+ input_, num_partitions=self.tp_size)
+ input_parallel = splitted_input[self.tp_rank].contiguous()
+
+ assert self.quant_method is not None
+ bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
+ output_parallel = self.quant_method.apply(self,
+ input_parallel,
+ bias=bias_)
+ output = self.comm_group.reduce_scatter(output_parallel, 0)
- # Matrix multiply.
- assert self.quant_method is not None
- # Only fuse bias add into GEMM for rank 0 (this ensures that
- # bias will not get added more than once in TP>1 case)
- bias_ = None if (self.tp_rank > 0
- or self.skip_bias_add) else self.bias
- output_parallel = self.quant_method.apply(self,
- input_parallel,
- bias=bias_)
- if self.reduce_results and self.tp_size > 1:
- output = tensor_model_parallel_all_reduce(output_parallel)
- else:
- output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
+ if not self.return_bias:
+ return output
+ return output, output_bias
+ # enable custom Oproj tensor parallel
+ def _forward_oproj_tp(
+ self,
+ input_: torch.Tensor,
+ ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
+
+ if self.input_is_parallel:
+ input_parallel = input_
+ else:
+ splitted_input = split_tensor_along_last_dim(
+ input_, num_partitions=self.tp_size)
+ input_parallel = splitted_input[self.tp_rank].contiguous()
+
+ # Prepare tensors for all-to-all communication
+ local_batch_size = input_parallel.size(0)
+ chunk_size = self.input_size_per_partition
+ total_batch_size = local_batch_size * self.tp_size
+
+ # Reshape tensor for efficient cross-device transfer:
+ # [batch, dim] -> [tp_size, batch, chunk] -> flattened
+ send_buf = (input_parallel.reshape(-1,
+ self.tp_size, chunk_size).transpose(
+ 0, 1).contiguous().view(-1))
+
+ # Create receive buffer
+ recv_buf = torch.empty(total_batch_size * chunk_size,
+ dtype=input_parallel.dtype,
+ device=input_parallel.device)
+
+ # Perform all-to-all communication
+ dist.all_to_all_single(recv_buf,
+ send_buf,
+ group=self.comm_group.device_group)
+ input_parallel = recv_buf.view(total_batch_size, chunk_size)
+
+ # Only fuse bias add for rank 0 to avoid duplicate bias addition in TP>1
+ bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
+ assert self.quant_method is not None
+ output_parallel = self.quant_method.apply(self,
+ input_parallel,
+ bias=bias_)
+
+ # otp-specific: Combine partial results across devices
+ output = self.comm_group.reduce_scatter(output_parallel, dim=0)
+
+ # Handle bias return based on configuration
+ output_bias = self.bias if self.skip_bias_add else None
+ if not self.return_bias:
+ return output
+ return output, output_bias
+
+ def _forward_matmul_allreduce(
+ self, input_: torch.Tensor
+ ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
+ if self.input_is_parallel:
+ input_parallel = input_
+ else:
+ splitted_input = split_tensor_along_last_dim(
+ input_, num_partitions=self.tp_size)
+ input_parallel = splitted_input[self.tp_rank].contiguous()
+ """Calculate the output tensor of forward by considering
+ fusing communication and computation."""
+ bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
+ if self.reduce_results and self.tp_size > 1:
+ output = torch_npu.npu_mm_all_reduce_base(input_parallel,
+ self.weight_t,
+ self.hcomm_info,
+ bias=bias_)
+ else:
+ output = self.quant_method.apply(self, input_parallel, bias=bias_)
+
+ output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias
-class AscendMlpMergedColumnParallelLinear(MergedColumnParallelLinear):
+class AscendMergedColumnParallelLinear(MergedColumnParallelLinear):
"""Packed linear layers with column parallelism.
Similar to ColumnParallelLinear, but the weight matrix is concatenated
@@ -262,48 +357,85 @@ class AscendMlpMergedColumnParallelLinear(MergedColumnParallelLinear):
*,
return_bias: bool = True,
):
- self.output_sizes = output_sizes
- if prefix.find("gate_up_proj") != -1:
- self.tp_size = get_mlp_tensor_model_parallel_world_size()
- self.tp_rank = get_mlp_tensor_model_parallel_rank()
- self.enable_mlp_optimze = True
+ self.comm_group = None
+ if prefix.find("gate_up_proj") != -1 and mlp_tp_enable():
+ self.comm_group = get_mlp_tp_group()
+ self.forward_type = "mlp_tp"
else:
- self.tp_size = get_tensor_model_parallel_world_size()
- self.tp_rank = get_tensor_model_parallel_rank()
- self.enable_mlp_optimze = False
+ self.comm_group = get_tp_group()
+ self.forward_type = "normal_tp"
+ self.tp_rank = self.comm_group.rank_in_group
+ self.tp_size = self.comm_group.world_size
+
+ self.output_sizes = output_sizes
assert all(output_size % self.tp_size == 0
for output_size in output_sizes)
- AscendMlpColumnParallelLinear.__init__(self,
- input_size=input_size,
- output_size=sum(output_sizes),
- bias=bias,
- gather_output=gather_output,
- skip_bias_add=skip_bias_add,
- params_dtype=params_dtype,
- quant_config=quant_config,
- prefix=prefix,
- return_bias=return_bias)
+ AscendColumnParallelLinear.__init__(self,
+ input_size=input_size,
+ output_size=sum(output_sizes),
+ bias=bias,
+ gather_output=gather_output,
+ skip_bias_add=skip_bias_add,
+ params_dtype=params_dtype,
+ quant_config=quant_config,
+ prefix=prefix,
+ return_bias=return_bias)
def forward(
self,
input_,
+ ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
+ if self.forward_type == "mlp_tp":
+ return self._forward_mlp_tp(input_)
+ else:
+ return super().forward(input_)
+
+ def _forward_mlp_tp(
+ self,
+ input_: torch.Tensor,
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
bias = self.bias if not self.skip_bias_add else None
- # self.global_batch_size = vllm_config.scheduler_config.max_num_seqs
# Matrix multiply.
assert self.quant_method is not None
- if self.enable_mlp_optimze:
- input2_ = get_mlp_tp_group().all_gather(input_, 0)
- output = self.quant_method.apply(self, input2_, bias)
- else:
- output_parallel = self.quant_method.apply(self, input_, bias)
- if self.gather_output:
- # All-gather across the partitions.
- output = tensor_model_parallel_all_gather(output_parallel)
- else:
- output = output_parallel
+ input_parallel = get_mlp_tp_group().all_gather(input_, 0)
+ output = self.quant_method.apply(self, input_parallel, bias)
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias
+
+
+class AscendLinearBase(LinearBase):
+
+ def __init__(
+ self,
+ input_size: int,
+ output_size: int,
+ skip_bias_add: bool = False,
+ params_dtype: Optional[torch.dtype] = None,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ *,
+ return_bias: bool = True,
+ disable_tp: bool = False,
+ ):
+ nn.Module.__init__(self)
+
+ # Keep input parameters
+ self.input_size = input_size
+ self.output_size = output_size
+ self.skip_bias_add = skip_bias_add
+ if params_dtype is None:
+ params_dtype = torch.get_default_dtype()
+ self.params_dtype = params_dtype
+ self.quant_config = quant_config
+ self.prefix = prefix
+ if quant_config is None:
+ self.quant_method: Optional[
+ QuantizeMethodBase] = UnquantizedLinearMethod()
+ else:
+ self.quant_method = quant_config.get_quant_method(self,
+ prefix=prefix)
+ self.return_bias = return_bias
+ self.disable_tp = disable_tp
diff --git a/vllm_ascend/patch/worker/patch_common/__init__.py b/vllm_ascend/patch/worker/patch_common/__init__.py
index 8d206bf..56af25a 100644
--- a/vllm_ascend/patch/worker/patch_common/__init__.py
+++ b/vllm_ascend/patch/worker/patch_common/__init__.py
@@ -16,7 +16,8 @@
#
import vllm_ascend.patch.worker.patch_common.patch_distributed # noqa
-import vllm_ascend.patch.worker.patch_common.patch_linear # noqa
import vllm_ascend.patch.worker.patch_common.patch_logits # noqa
+import vllm_ascend.patch.worker.patch_common.patch_lora # noqa
import vllm_ascend.patch.worker.patch_common.patch_lora_embedding # noqa
+import vllm_ascend.patch.worker.patch_common.patch_lora_linear # noqa
import vllm_ascend.patch.worker.patch_common.patch_minicpm # noqa
diff --git a/vllm_ascend/patch/worker/patch_common/patch_linear.py b/vllm_ascend/patch/worker/patch_common/patch_linear.py
deleted file mode 100644
index 5690ba8..0000000
--- a/vllm_ascend/patch/worker/patch_common/patch_linear.py
+++ /dev/null
@@ -1,147 +0,0 @@
-"""
-Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
-This file is a part of the vllm-ascend project.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
-
-from typing import Optional, Union
-
-import torch
-import torch_npu
-import vllm
-from torch.distributed import ProcessGroup
-from torch.nn.parameter import Parameter
-from vllm.distributed import (get_tensor_model_parallel_rank,
- split_tensor_along_last_dim)
-from vllm.distributed.parallel_state import get_tp_group
-from vllm.logger import logger
-from vllm.model_executor.layers.linear import RowParallelLinear
-
-import vllm_ascend.envs as envs_ascend
-
-_HCOMM_INFO = None
-
-
-class AscendRowParallelLinear(RowParallelLinear):
- """
- AscendRowParallelLinear is a custom implementation of RowParallelLinear
- that overrides the forward method to handle Ascend-specific operations.
- """
-
- def __init__(self, *args, **kwargs):
- """Initialize the AscendRowParallelLinear layer.
-
- Args:
- *args: Variable length argument list.
- **kwargs: Arbitrary keyword arguments.
- """
- tp_group = get_tp_group().device_group
- hcomm_info = self.get_hcomm_info(tp_group)
- self.hcomm_info = hcomm_info
- super().__init__(*args, **kwargs)
- self.weight_t = self.weight.t()
-
- @staticmethod
- def get_hcomm_info(group: ProcessGroup) -> str:
- """Get the HCCL communication information for the given group.
-
- Args:
- group (ProcessGroup): The process group for which to get the HCCL communication info.
-
- Returns:
- str: The HCCL communication name for the given group.
- """
- global _HCOMM_INFO
- if _HCOMM_INFO is not None:
- return _HCOMM_INFO
-
- rank = torch.distributed.get_rank(group)
- if torch.__version__ > "2.0":
- global_rank = torch.distributed.get_global_rank(group, rank)
- _HCOMM_INFO = group._get_backend(
- torch.device("npu")).get_hccl_comm_name(global_rank)
-
- else:
- _HCOMM_INFO = group.get_hccl_comm_name(rank)
- return _HCOMM_INFO
-
- def forward(
- self, input_: torch.Tensor
- ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
- """Forward pass for the AscendRowParallelLinear layer.
-
- Args:
- input_ (torch.Tensor): the input tensor to the layer.
-
- Returns:
- Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
- The output tensor after applying the linear transformation,
- and optionally the bias if `return_bias` is True.
- """
- input_parallel = self.calc_input(input_)
-
- # Matrix multiply.
- assert self.quant_method is not None
- # Only fuse bias add into GEMM for rank 0 (this ensures that
- # bias will not get added more than once in TP>1 case)
- output = self.calc_output(input_parallel)
-
- output_bias = self.bias if self.skip_bias_add else None
-
- if not self.return_bias:
- return output
- return output, output_bias
-
- def calc_input(self, input_: torch.Tensor) -> torch.Tensor:
- """Calculate the input tensor for parallel processing.
-
- Args:
- input_ (torch.Tensor): the input tensor to be processed.
-
- Returns:
- torch.Tensor: The input tensor split along the last dimension
- for tensor model parallelism, or the original input if not parallel.
- """
- if self.input_is_parallel:
- return input_
- tp_rank = get_tensor_model_parallel_rank()
- splitted_input = split_tensor_along_last_dim(
- input_, num_partitions=self.tp_size)
- return splitted_input[tp_rank].contiguous()
-
- def calc_output(self, input_parallel: torch.Tensor) -> torch.Tensor:
- """Calculate the output tensor of forward by considering
- fusing communication and computation.
-
- Args:
- input_parallel (_type_): the input tensor to be processed in parallel.
-
- Returns:
- torch.Tensor: the output tensor after applying the linear transformation
- and optionally handle communication between tensor model parallel ranks.
- """
- bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
- if self.reduce_results and self.tp_size > 1:
- output = torch_npu.npu_mm_all_reduce_base(input_parallel,
- self.weight_t,
- self.hcomm_info,
- bias=bias_)
- else:
- output = self.quant_method.apply(self, input_parallel, bias=bias_)
- return output
-
-
-if envs_ascend.VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE:
- logger.info("AscendRowParallelLinear: Matmul all-reduce is enabled. ")
- vllm.model_executor.layers.linear.RowParallelLinear = AscendRowParallelLinear
diff --git a/vllm_ascend/patch/worker/patch_common/patch_lora.py b/vllm_ascend/patch/worker/patch_common/patch_lora.py
new file mode 100644
index 0000000..e96f971
--- /dev/null
+++ b/vllm_ascend/patch/worker/patch_common/patch_lora.py
@@ -0,0 +1,15 @@
+import vllm
+from vllm.lora.utils import _all_lora_classes
+
+from vllm_ascend.patch.worker.patch_common.patch_lora_embedding import \
+ AscendVocabParallelEmbeddingWithLoRA
+from vllm_ascend.patch.worker.patch_common.patch_lora_linear import (
+ AscendColumnParallelLinearWithLoRA,
+ AscendMergedColumnParallelLinearWithLoRA, AscendRowParallelLinearWithLoRA)
+
+_all_lora_classes.add(AscendRowParallelLinearWithLoRA)
+_all_lora_classes.add(AscendColumnParallelLinearWithLoRA)
+_all_lora_classes.add(AscendMergedColumnParallelLinearWithLoRA)
+_all_lora_classes.add(AscendVocabParallelEmbeddingWithLoRA)
+
+vllm.lora.utils._all_lora_classes = _all_lora_classes
diff --git a/vllm_ascend/patch/worker/patch_common/patch_lora_embedding.py b/vllm_ascend/patch/worker/patch_common/patch_lora_embedding.py
index 02d5804..eab545b 100644
--- a/vllm_ascend/patch/worker/patch_common/patch_lora_embedding.py
+++ b/vllm_ascend/patch/worker/patch_common/patch_lora_embedding.py
@@ -1,11 +1,9 @@
from typing import Optional
-import vllm
from torch import nn
from transformers import PretrainedConfig
from vllm.config import LoRAConfig
from vllm.lora.layers import VocabParallelEmbeddingWithLoRA
-from vllm.lora.utils import _all_lora_classes
from vllm_ascend.ops.vocab_parallel_embedding import \
AscendVocabParallelEmbedding
@@ -22,8 +20,3 @@ class AscendVocabParallelEmbeddingWithLoRA(VocabParallelEmbeddingWithLoRA):
model_config: Optional[PretrainedConfig],
) -> bool:
return type(source_layer) is AscendVocabParallelEmbedding
-
-
-# Patch for lora register_model issue after overriding VocabParallelEmbedding class (#2515)
-_all_lora_classes.add(AscendVocabParallelEmbeddingWithLoRA)
-vllm.lora.utils._all_lora_classes = _all_lora_classes
diff --git a/vllm_ascend/patch/worker/patch_common/patch_lora_linear.py b/vllm_ascend/patch/worker/patch_common/patch_lora_linear.py
new file mode 100644
index 0000000..fdfe51d
--- /dev/null
+++ b/vllm_ascend/patch/worker/patch_common/patch_lora_linear.py
@@ -0,0 +1,52 @@
+from typing import Optional
+
+from torch import nn
+from transformers import PretrainedConfig
+from vllm.config import LoRAConfig
+from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
+ MergedColumnParallelLinearWithLoRA,
+ RowParallelLinearWithLoRA)
+
+from vllm_ascend.ops.linear import (AscendColumnParallelLinear,
+ AscendMergedColumnParallelLinear,
+ AscendRowParallelLinear)
+
+
+class AscendRowParallelLinearWithLoRA(RowParallelLinearWithLoRA):
+
+ @classmethod
+ def can_replace_layer(
+ cls,
+ source_layer: nn.Module,
+ lora_config: LoRAConfig,
+ packed_modules_list: list,
+ model_config: Optional[PretrainedConfig],
+ ) -> bool:
+ return type(source_layer) is AscendRowParallelLinear
+
+
+class AscendColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
+
+ @classmethod
+ def can_replace_layer(
+ cls,
+ source_layer: nn.Module,
+ lora_config: LoRAConfig,
+ packed_modules_list: list,
+ model_config: Optional[PretrainedConfig],
+ ) -> bool:
+ return type(source_layer) is AscendColumnParallelLinear
+
+
+class AscendMergedColumnParallelLinearWithLoRA(
+ MergedColumnParallelLinearWithLoRA):
+
+ @classmethod
+ def can_replace_layer(
+ cls,
+ source_layer: nn.Module,
+ lora_config: LoRAConfig,
+ packed_modules_list: list,
+ model_config: Optional[PretrainedConfig],
+ ) -> bool:
+ return type(source_layer) is AscendMergedColumnParallelLinear
diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py
index 94ec99a..d4b2c4e 100644
--- a/vllm_ascend/platform.py
+++ b/vllm_ascend/platform.py
@@ -132,7 +132,6 @@ class NPUPlatform(Platform):
"kv_cache_dtype", None)
if kv_cache_dtype is not None:
vllm_config.cache_config.cache_dtype = kv_cache_dtype
-
if model_config is None:
logger.warning("Model config is missing. This may indicate "
"that we are running a test case")
diff --git a/vllm_ascend/quantization/quant_config.py b/vllm_ascend/quantization/quant_config.py
index 7299dbe..9cf84e8 100644
--- a/vllm_ascend/quantization/quant_config.py
+++ b/vllm_ascend/quantization/quant_config.py
@@ -35,8 +35,11 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.parameter import PerTensorScaleParameter
from vllm.model_executor.utils import set_weight_attrs
+from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group,
+ get_otp_group)
from vllm_ascend.ops.fused_moe import AscendUnquantizedFusedMoEMethod
-from vllm_ascend.utils import ASCEND_QUANTIZATION_METHOD
+from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, mlp_tp_enable,
+ oproj_tp_enable)
from .utils import get_quant_method
@@ -220,9 +223,15 @@ class AscendLinearMethod(LinearMethodBase):
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if isinstance(layer, RowParallelLinear):
- tp_rank = get_tensor_model_parallel_rank()
- return self.quant_method.apply(layer, x, bias, tp_rank)
- return self.quant_method.apply(layer, x, bias)
+ if layer.prefix.find("o_proj") != -1 and oproj_tp_enable():
+ tp_rank = get_otp_group().rank_in_group
+ elif layer.prefix.find("down_proj") != -1 and mlp_tp_enable():
+ tp_rank = get_mlp_tp_group().rank_in_group
+ else:
+ tp_rank = get_tensor_model_parallel_rank()
+ else:
+ tp_rank = 0
+ return self.quant_method.apply(layer, x, bias, tp_rank)
class AscendKVCacheMethod(BaseKVCacheMethod):
diff --git a/vllm_ascend/torchair/models/torchair_deepseek_v2.py b/vllm_ascend/torchair/models/torchair_deepseek_v2.py
index b31549d..ec48b56 100644
--- a/vllm_ascend/torchair/models/torchair_deepseek_v2.py
+++ b/vllm_ascend/torchair/models/torchair_deepseek_v2.py
@@ -74,7 +74,7 @@ from vllm_ascend.quantization.quant_config import AscendLinearMethod
from vllm_ascend.torchair.ops.torchair_fused_moe import TorchairAscendFusedMoE
from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import \
TorchairAscendW8A8DynamicLinearMethod
-from vllm_ascend.utils import dispose_tensor, npu_prefetch
+from vllm_ascend.utils import dispose_tensor, npu_prefetch, oproj_tp_enable
class TorchairDeepseekV2SiluAndMul(SiluAndMul):
@@ -514,11 +514,18 @@ class TorchairDeepseekV2MLAAttention(DeepseekV2MLAAttention):
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.kv_b_proj")
- if (config.n_routed_experts is not None
- and self.debug_layer_idx >= config.first_k_dense_replace
- and self.debug_layer_idx % config.moe_layer_freq == 0
- and (ascend_config.torchair_graph_config.enable_multistream_moe
- or self.enable_shared_expert_dp)):
+
+ if oproj_tp_enable():
+ self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
+ self.hidden_size,
+ bias=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.o_proj")
+ elif (config.n_routed_experts is not None
+ and self.debug_layer_idx >= config.first_k_dense_replace
+ and self.debug_layer_idx % config.moe_layer_freq == 0
+ and (ascend_config.torchair_graph_config.enable_multistream_moe
+ or self.enable_shared_expert_dp)):
self.o_proj = TorchairDeepseekV2RowParallelLinearReplaceAllreduce(
self.num_heads * self.v_head_dim,
self.hidden_size,
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index f3c1aef..2959310 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -491,9 +491,9 @@ def register_ascend_customop():
from vllm.model_executor.custom_op import CustomOp
from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul
- from vllm_ascend.ops.linear import (AscendMlpColumnParallelLinear,
- AscendMlpMergedColumnParallelLinear,
- AscendMlpRowParallelLinear)
+ from vllm_ascend.ops.linear import (AscendColumnParallelLinear,
+ AscendMergedColumnParallelLinear,
+ AscendRowParallelLinear)
from vllm_ascend.ops.rotary_embedding import (
AscendDeepseekScalingRotaryEmbedding, AscendRotaryEmbedding)
from vllm_ascend.ops.vocab_parallel_embedding import (
@@ -504,6 +504,12 @@ def register_ascend_customop():
name="SiluAndMul")
CustomOp.register_oot(_decorated_op_cls=AscendRotaryEmbedding,
name="RotaryEmbedding")
+ CustomOp.register_oot(_decorated_op_cls=AscendColumnParallelLinear,
+ name="ColumnParallelLinear")
+ CustomOp.register_oot(_decorated_op_cls=AscendRowParallelLinear,
+ name="RowParallelLinear")
+ CustomOp.register_oot(_decorated_op_cls=AscendMergedColumnParallelLinear,
+ name="MergedColumnParallelLinear")
CustomOp.register_oot(
_decorated_op_cls=AscendDeepseekScalingRotaryEmbedding,
name="DeepseekScalingRotaryEmbedding")
@@ -513,14 +519,6 @@ def register_ascend_customop():
name="ParallelLMHead")
CustomOp.register_oot(_decorated_op_cls=AscendLogitsProcessor,
name="LogitsProcessor")
- if envs_ascend.VLLM_ASCEND_ENABLE_MLP_OPTIMIZE:
- CustomOp.register_oot(_decorated_op_cls=AscendMlpColumnParallelLinear,
- name="ColumnParallelLinear")
- CustomOp.register_oot(_decorated_op_cls=AscendMlpRowParallelLinear,
- name="RowParallelLinear")
- CustomOp.register_oot(
- _decorated_op_cls=AscendMlpMergedColumnParallelLinear,
- name="MergedColumnParallelLinear")
from vllm_ascend.ops.layernorm import AscendRMSNorm
CustomOp.register_oot(_decorated_op_cls=AscendRMSNorm, name="RMSNorm")
@@ -562,3 +560,15 @@ def get_ascend_soc_version():
def lmhead_tp_enable() -> bool:
return get_ascend_config().lmhead_tensor_parallel_size is not None
+
+
+def oproj_tp_enable() -> bool:
+ return get_ascend_config().oproj_tensor_parallel_size is not None
+
+
+def mlp_tp_enable() -> bool:
+ return envs_ascend.VLLM_ASCEND_ENABLE_MLP_OPTIMIZE
+
+
+def matmul_allreduce_enable() -> bool:
+ return envs_ascend.VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE