From 600b08f7542be3409c2c70927c91471e8de33d03 Mon Sep 17 00:00:00 2001
From: lidenghui1110 <30521952+lidenghui1110@users.noreply.github.com>
Date: Fri, 29 Aug 2025 11:41:21 +0800
Subject: [PATCH] [Feat]: Add custom lmhead tensor model parallel (#2309)
### What this PR does / why we need it?
This PR introduces LMhead tensor model parallel to achieve decreasing of
memory consumption, and TPOT performance improvement. It support both
eager mode and graph mode.
In deepseek r1 w8a8 PD disagregated Decode instance, using pure DP, with
lmhead_tensor_parallel_size = 8, we have 1 ms TPOT optimization, saved
1.48 GB NPU memory per RANK.
performance data:
### Does this PR introduce _any_ user-facing change?
This PR introduces one new config in `additional_config`.
| Name | Effect | Required | Type | Constraints |
| :---------------------------- |
:--------------------------------------- | :------- | :--- |
:----------------- |
| lmhead_tensor_parallel_size | Split the lm_head matrix along the
column dimension (vocab_size) into lmhead_tensor_parallel_size pieces |
No | int | default value is None, once this value is set, the feature
will be enabled, vocab_size must be divisible by this value. |
example
`--additional_config={"lmhead_tensor_parallel_size": 8}`
### How was this patch tested?
- vLLM version: v0.10.1.1
- vLLM main:
https://github.com/vllm-project/vllm/commit/de533ab2a14192e461900a4950e2b426d99a6862
---------
Signed-off-by: zzhx1
Co-authored-by: zhangzihang
---
.../configuration/additional_config.md | 1 +
tests/ut/distributed/test_parallel_state.py | 44 ++++
tests/ut/models/test_deepseek_mtp.py | 17 +-
tests/ut/models/test_deepseek_v2.py | 29 ++-
tests/ut/ops/test_vocab_parallel_embedding.py | 62 +++++-
tests/ut/test_ascend_config.py | 13 +-
tests/ut/test_utils.py | 4 +-
.../models/test_torchair_deepseek_mtp.py | 17 +-
vllm_ascend/ascend_config.py | 10 +
vllm_ascend/distributed/parallel_state.py | 31 +++
vllm_ascend/ops/vocab_parallel_embedding.py | 189 +++++++++++++++++-
vllm_ascend/utils.py | 19 +-
vllm_ascend/worker/model_runner_v1.py | 30 ++-
vllm_ascend/worker/mtp_proposer_v1.py | 14 +-
14 files changed, 458 insertions(+), 22 deletions(-)
create mode 100644 tests/ut/distributed/test_parallel_state.py
diff --git a/docs/source/user_guide/configuration/additional_config.md b/docs/source/user_guide/configuration/additional_config.md
index cf4d1dc..cdb0908 100644
--- a/docs/source/user_guide/configuration/additional_config.md
+++ b/docs/source/user_guide/configuration/additional_config.md
@@ -34,6 +34,7 @@ The following table lists the additional configuration options available in vLLM
| `enable_prefetch` | bool | `False` | Whether to enable weight prefetch. |
| `kv_cache_dtype` | str | `None` | When using the kv cache quantization method, kv cache dtype needs to be set, currently only int8 is supported. |
| `enable_shared_expert_dp` | bool | `False` | When the shared expert in DP, it has better performance but consumes more memory. Currently only DeepSeek series models are supported to use. |
+| `lmhead_tensor_parallel_size` | int | `None` | The custom tensor parallel size of lmhead. |
The details of each config option are as follows:
diff --git a/tests/ut/distributed/test_parallel_state.py b/tests/ut/distributed/test_parallel_state.py
new file mode 100644
index 0000000..afc22c8
--- /dev/null
+++ b/tests/ut/distributed/test_parallel_state.py
@@ -0,0 +1,44 @@
+from unittest.mock import MagicMock, patch
+
+import pytest
+from vllm.config import ParallelConfig
+
+from vllm_ascend.distributed.parallel_state import (
+ _LMTP, _MC2, destroy_ascend_model_parallel, get_lmhead_tp_group,
+ get_mc2_group, init_ascend_model_parallel)
+
+
+@pytest.fixture
+def parallel_config():
+ return ParallelConfig(data_parallel_size=2,
+ tensor_parallel_size=2,
+ pipeline_parallel_size=2)
+
+
+@pytest.fixture
+def mock_distributed():
+ with patch('torch.distributed.is_initialized', return_value=True), \
+ patch('torch.distributed.get_world_size', return_value=8), \
+ patch('torch.distributed.get_backend', return_value='nccl'), \
+ patch('vllm_ascend.distributed.parallel_state.get_world_group') as mock_group:
+ mock_group.return_value.local_rank = 0
+ mock_group.return_value.device_group = MagicMock()
+ yield
+
+
+def test_init_ascend_model_parallel(mock_distributed, parallel_config):
+ mock_ascend_config = MagicMock()
+ mock_ascend_config.lmhead_tensor_parallel_size = 2
+ with patch('vllm_ascend.distributed.parallel_state.model_parallel_initialized', return_value=False), \
+ patch('vllm_ascend.distributed.parallel_state.init_model_parallel_group'), \
+ patch('vllm_ascend.distributed.parallel_state.get_ascend_config', return_value=mock_ascend_config):
+ init_ascend_model_parallel(parallel_config)
+
+ mc2_group = get_mc2_group()
+ assert mc2_group is not None
+ lmheadtp_group = get_lmhead_tp_group()
+ assert lmheadtp_group is not None
+
+ destroy_ascend_model_parallel()
+ assert _MC2 is None
+ assert _LMTP is None
diff --git a/tests/ut/models/test_deepseek_mtp.py b/tests/ut/models/test_deepseek_mtp.py
index 45b6ed5..61fdf98 100644
--- a/tests/ut/models/test_deepseek_mtp.py
+++ b/tests/ut/models/test_deepseek_mtp.py
@@ -31,6 +31,11 @@ class TestCustomDeepSeekMultiTokenPredictorLayer(PytestBase):
mocker_deepseek_v2_decode_layer = mocker.patch(
"vllm_ascend.models.deepseek_v2.CustomDeepseekV2DecoderLayer.__init__",
return_value=None)
+ mocker.patch(
+ "vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding.__init__",
+ return_value=None)
+ mocker.patch("vllm_ascend.utils.get_ascend_config",
+ return_value=mocker.Mock())
mtp_layer = CustomDeepSeekMultiTokenPredictorLayer(config, "", None)
mocker_deepseek_v2_decode_layer.assert_called_once()
@@ -83,6 +88,11 @@ class TestCustomDeepSeekMultiTokenPredictor(PytestBase):
mocker.patch(
"vllm_ascend.models.deepseek_mtp.CustomDeepSeekMultiTokenPredictorLayer.__init__",
return_value=None)
+ mocker.patch(
+ "vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding.__init__",
+ return_value=None)
+ mocker.patch("vllm_ascend.utils.get_ascend_config",
+ return_value=mocker.Mock())
predictor = CustomDeepSeekMultiTokenPredictor(
vllm_config=mock_vllm_config)
@@ -157,6 +167,11 @@ class TestCustomDeepSeekMTP(PytestBase):
return_value=None)
mocker.patch("vllm.model_executor.layers.sampler.get_sampler",
return_value=None)
+ mocker.patch(
+ "vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding.__init__",
+ return_value=None)
+ mocker.patch("vllm_ascend.utils.get_ascend_config",
+ return_value=mocker.Mock())
mtp = CustomDeepSeekMTP(vllm_config=vllm_config)
return mtp
@@ -177,4 +192,4 @@ class TestCustomDeepSeekMTP(PytestBase):
output = setup_mtp.forward(input_ids, positions, kv_caches, None,
previous_hidden_states, inputs_embeds,
spec_step_idx)
- assert torch.allclose(output, torch.tensor([[1.0, 2.0, 3.0]]))
\ No newline at end of file
+ assert torch.allclose(output, torch.tensor([[1.0, 2.0, 3.0]]))
diff --git a/tests/ut/models/test_deepseek_v2.py b/tests/ut/models/test_deepseek_v2.py
index e0c50e8..df14a2a 100644
--- a/tests/ut/models/test_deepseek_v2.py
+++ b/tests/ut/models/test_deepseek_v2.py
@@ -26,7 +26,7 @@ from vllm_ascend.models.deepseek_v2 import (
CustomDeepseekV2MLP, CustomDeepseekV2MoE,
CustomDeepseekV2RowParallelLinear,
CustomDeepseekV2RowParallelLinearReplaceAllreduce,
- CustomDeepseekV2SiluAndMul)
+ CustomDeepseekV2SiluAndMul, LogitsProcessor, ParallelLMHead)
@pytest.fixture
@@ -266,3 +266,30 @@ def test_custom_deepseek_v2_mla_attention(mock_rms_norm, mock_distributed,
kv_lora_rank=16,
prefix="layers.1.self_attn")
assert hasattr(attn, "q_proj")
+
+
+def test_deepseek_v2_lmhead(mock_distributed, vllm_config):
+ # 创建一个简单的配置对象
+ class SimpleConfig:
+
+ def __init__(self):
+ self.vocab_size = 10000
+ self.hidden_size = 128
+
+ config = SimpleConfig()
+
+ # 直接创建lmhead和logits_processor
+ lmhead = ParallelLMHead(config.vocab_size, config.hidden_size)
+ logits_processor = LogitsProcessor(config.vocab_size)
+
+ # 创建模拟输出
+ mock_output = torch.randn(2, 4, config.hidden_size)
+ mock_logits = torch.randn(2, 4, config.vocab_size)
+
+ # 直接测试logits_processor
+ with patch.object(lmhead.quant_method, "apply", return_value=mock_logits):
+ with patch.object(logits_processor,
+ "_gather_logits",
+ return_value=mock_logits):
+ logits = logits_processor(lmhead, mock_output)
+ assert logits.shape == (2, 4, config.vocab_size)
diff --git a/tests/ut/ops/test_vocab_parallel_embedding.py b/tests/ut/ops/test_vocab_parallel_embedding.py
index 13ede67..5378b19 100644
--- a/tests/ut/ops/test_vocab_parallel_embedding.py
+++ b/tests/ut/ops/test_vocab_parallel_embedding.py
@@ -18,8 +18,8 @@ from unittest.mock import MagicMock, patch
import torch
-from vllm_ascend.ops.vocab_parallel_embedding import \
- AscendVocabParallelEmbedding
+from vllm_ascend.ops.vocab_parallel_embedding import (
+ AscendLogitsProcessor, AscendParallelLMHead, AscendVocabParallelEmbedding)
VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS = 128
@@ -34,7 +34,11 @@ class TestCustomVocabParallelEmbedding(unittest.TestCase):
def _create_layer(self):
# Patch methods and dependencies for VocabParallelEmbedding
- with patch("vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_rank", return_value=0), \
+ mock_group = MagicMock()
+ mock_group.world_size = 2
+ mock_group.rank_in_group = 0
+ with patch("vllm_ascend.ops.vocab_parallel_embedding.get_tp_group", return_value=mock_group), \
+ patch("vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_rank", return_value=0), \
patch("vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_world_size", return_value=2), \
patch("vllm.model_executor.layers.vocab_parallel_embedding.pad_vocab_size", side_effect=lambda x, y: x + y), \
patch("vllm.model_executor.layers.vocab_parallel_embedding.divide", side_effect=lambda x, y: x // y):
@@ -174,3 +178,55 @@ class TestCustomVocabParallelEmbedding(unittest.TestCase):
# Call the forward method
output = layer.forward(input_)
self.assertEqual(output.shape, expected_shape)
+
+
+class TestAscendLogitsProcessor(unittest.TestCase):
+
+ def setUp(self):
+ self.vocab_size = 50
+ self.num_embeddings = 50
+ self.embedding_dim = 10
+ self.org_num_embeddings = 40
+ self.padding_size = 8
+
+ self.mock_group = MagicMock()
+ self.mock_group.world_size = 2
+ self.mock_group.rank_in_group = 0
+ self.mock_ascend_config = MagicMock()
+ self.mock_quant_method = MagicMock()
+ self.mock_quant_method.apply = MagicMock(
+ return_value=torch.randn(1, self.vocab_size))
+ self.patches = [
+ patch("vllm_ascend.ascend_config.get_ascend_config",
+ return_value=self.mock_ascend_config),
+ patch(
+ "vllm_ascend.ops.vocab_parallel_embedding.get_lmhead_tp_group",
+ return_value=self.mock_group),
+ patch("vllm_ascend.ops.vocab_parallel_embedding.lmhead_tp_enable",
+ return_value=True),
+ patch(
+ "vllm_ascend.ops.vocab_parallel_embedding.get_lmhead_tp_group.all_to_all",
+ return_value=torch.randn(1, self.vocab_size))
+ ]
+
+ for p in self.patches:
+ p.start()
+
+ def tearDown(self):
+ for p in self.patches:
+ p.stop()
+
+ def test_create_processor(self):
+ processor = AscendLogitsProcessor(vocab_size=self.vocab_size)
+ self.assertEqual(processor.vocab_size, self.vocab_size)
+
+ def test_get_logits(self):
+ processor = AscendLogitsProcessor(vocab_size=self.vocab_size)
+ lmhead = AscendParallelLMHead(num_embeddings=self.num_embeddings,
+ embedding_dim=self.embedding_dim,
+ prefix="lm_head")
+ lmhead.quant_method = self.mock_quant_method
+ lmhead.quant_method.apply = self.mock_quant_method.apply
+ hidden_state = torch.randn(1, self.org_num_embeddings)
+ processor._get_logits(hidden_state, lmhead)
+ self.mock_quant_method.apply.assert_called_once()
diff --git a/tests/ut/test_ascend_config.py b/tests/ut/test_ascend_config.py
index 49b9abe..622b751 100644
--- a/tests/ut/test_ascend_config.py
+++ b/tests/ut/test_ascend_config.py
@@ -16,7 +16,7 @@
import os
from transformers import PretrainedConfig
-from vllm.config import ModelConfig, VllmConfig
+from vllm.config import ModelConfig, ParallelConfig, VllmConfig
from tests.ut.base import TestBase
from vllm_ascend.ascend_config import (_check_torchair_supported,
@@ -75,7 +75,7 @@ class TestAscendConfig(TestBase):
"enabled": True
},
"expert_map_path": "test_expert_map_path",
- "refresh": True
+ "refresh": True,
}
ascend_config = init_ascend_config(test_vllm_config)
self.assertEqual(ascend_config.expert_map_path, "test_expert_map_path")
@@ -304,3 +304,12 @@ class TestAscendConfig(TestBase):
"refresh": True
}
init_ascend_config(test_vllm_config)
+
+ with self.assertRaises(AssertionError):
+ test_vllm_config.additional_config = {
+ "lmhead_tensor_parallel_size": 2,
+ "refresh": True
+ }
+ test_vllm_config.parallel_config = ParallelConfig(
+ data_parallel_size=4, tensor_parallel_size=2)
+ init_ascend_config(test_vllm_config)
diff --git a/tests/ut/test_utils.py b/tests/ut/test_utils.py
index 396f457..0d264c7 100644
--- a/tests/ut/test_utils.py
+++ b/tests/ut/test_utils.py
@@ -289,13 +289,13 @@ class TestUtils(TestBase):
# ascend custom op is not registered
utils.register_ascend_customop()
# should call register_oot three
- self.assertEqual(mock_customop.register_oot.call_count, 10)
+ self.assertEqual(mock_customop.register_oot.call_count, 12)
self.assertTrue(utils._ASCEND_CUSTOMOP_IS_REIGISTERED)
# ascend custom op is already registered
utils.register_ascend_customop()
# should not register_oot again, thus only called three in this ut
- self.assertEqual(mock_customop.register_oot.call_count, 10)
+ self.assertEqual(mock_customop.register_oot.call_count, 12)
class TestProfileExecuteDuration(TestBase):
diff --git a/tests/ut/torchair/models/test_torchair_deepseek_mtp.py b/tests/ut/torchair/models/test_torchair_deepseek_mtp.py
index 1c1e6c7..7aafdfc 100644
--- a/tests/ut/torchair/models/test_torchair_deepseek_mtp.py
+++ b/tests/ut/torchair/models/test_torchair_deepseek_mtp.py
@@ -31,6 +31,11 @@ class TestTorchairDeepSeekMultiTokenPredictorLayer(PytestBase):
mocker_deepseek_v2_decode_layer = mocker.patch(
"vllm_ascend.torchair.models.torchair_deepseek_v2.TorchairDeepseekV2DecoderLayer.__init__",
return_value=None)
+ mocker.patch(
+ "vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding.__init__",
+ return_value=None)
+ mocker.patch("vllm_ascend.utils.get_ascend_config",
+ return_value=mocker.Mock())
mtp_layer = TorchairDeepSeekMultiTokenPredictorLayer(config, "", None)
mocker_deepseek_v2_decode_layer.assert_called_once()
@@ -83,6 +88,11 @@ class TestTorchairDeepSeekMultiTokenPredictor(PytestBase):
mocker.patch(
"vllm_ascend.torchair.models.torchair_deepseek_mtp.TorchairDeepSeekMultiTokenPredictorLayer.__init__",
return_value=None)
+ mocker.patch(
+ "vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding.__init__",
+ return_value=None)
+ mocker.patch("vllm_ascend.utils.get_ascend_config",
+ return_value=mocker.Mock())
predictor = TorchairDeepSeekMultiTokenPredictor(
vllm_config=mock_vllm_config)
@@ -157,6 +167,11 @@ class TestTorchairDeepSeekMTP(PytestBase):
return_value=None)
mocker.patch("vllm.model_executor.layers.sampler.get_sampler",
return_value=None)
+ mocker.patch(
+ "vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding.__init__",
+ return_value=None)
+ mocker.patch("vllm_ascend.utils.get_ascend_config",
+ return_value=mocker.Mock())
mtp = TorchairDeepSeekMTP(vllm_config=vllm_config)
return mtp
@@ -177,4 +192,4 @@ class TestTorchairDeepSeekMTP(PytestBase):
output = setup_mtp.forward(input_ids, positions, kv_caches, None,
previous_hidden_states, inputs_embeds,
spec_step_idx)
- assert torch.allclose(output, torch.tensor([[1.0, 2.0, 3.0]]))
\ No newline at end of file
+ assert torch.allclose(output, torch.tensor([[1.0, 2.0, 3.0]]))
diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py
index 81cf177..2a2ac7b 100644
--- a/vllm_ascend/ascend_config.py
+++ b/vllm_ascend/ascend_config.py
@@ -51,6 +51,16 @@ class AscendConfig:
"enable_shared_expert_dp", False
) and not self.torchair_graph_config.enabled and vllm_config.parallel_config.enable_expert_parallel
self.enable_prefetch = additional_config.get("enable_prefetch", False)
+ self.lmhead_tensor_parallel_size = additional_config.get(
+ "lmhead_tensor_parallel_size", None)
+ if self.lmhead_tensor_parallel_size is not None:
+ logger.info(
+ f"Enable lmhead_tensor_parallel_size={self.lmhead_tensor_parallel_size} in pure DP scenario"
+ )
+ if vllm_config.parallel_config.tensor_parallel_size != 1:
+ raise AssertionError(
+ "lmhead_tensor_parallel_size is only supported in the pure DP scenario"
+ )
class TorchairGraphConfig:
diff --git a/vllm_ascend/distributed/parallel_state.py b/vllm_ascend/distributed/parallel_state.py
index db1c5a8..f81d501 100644
--- a/vllm_ascend/distributed/parallel_state.py
+++ b/vllm_ascend/distributed/parallel_state.py
@@ -6,17 +6,26 @@ from vllm.distributed.parallel_state import (GroupCoordinator, get_world_group,
init_model_parallel_group)
import vllm_ascend.envs as envs_ascend
+from vllm_ascend.ascend_config import get_ascend_config
# Currently, mc2 op need their own group coordinator.
_MC2: Optional[GroupCoordinator] = None
_MLP_TP: Optional[GroupCoordinator] = None
+_LMTP: Optional[GroupCoordinator] = None
+
def get_mc2_group() -> GroupCoordinator:
assert _MC2 is not None, ("mc2 group is not initialized")
return _MC2
+def get_lmhead_tp_group() -> GroupCoordinator:
+ assert _LMTP is not None, (
+ "lm head tensor parallel group is not initialized")
+ return _LMTP
+
+
def get_mlp_tp_group() -> GroupCoordinator:
assert _MLP_TP is not None, ("mlp group is not initialized")
return _MLP_TP
@@ -65,6 +74,23 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
backend,
group_name="mlp_tp")
+ lmhead_tensor_parallel_size = get_ascend_config(
+ ).lmhead_tensor_parallel_size
+ if lmhead_tensor_parallel_size is not None:
+ group_ranks = []
+ global _LMTP
+ num_lmhead_tensor_parallel_groups: int = (world_size //
+ lmhead_tensor_parallel_size)
+ for i in range(num_lmhead_tensor_parallel_groups):
+ ranks = list(
+ range(i * lmhead_tensor_parallel_size,
+ (i + 1) * lmhead_tensor_parallel_size))
+ group_ranks.append(ranks)
+ _LMTP = init_model_parallel_group(group_ranks,
+ get_world_group().local_rank,
+ backend,
+ group_name="lmheadtp")
+
def get_mlp_tensor_model_parallel_world_size():
"""Return world size for the tensor model parallel group."""
@@ -86,3 +112,8 @@ def destroy_ascend_model_parallel():
if _MLP_TP:
_MLP_TP.destroy()
_MLP_TP = None
+
+ global _LMTP
+ if _LMTP:
+ _LMTP.destroy()
+ _LMTP = None
diff --git a/vllm_ascend/ops/vocab_parallel_embedding.py b/vllm_ascend/ops/vocab_parallel_embedding.py
index 05b08a4..7ad35dc 100644
--- a/vllm_ascend/ops/vocab_parallel_embedding.py
+++ b/vllm_ascend/ops/vocab_parallel_embedding.py
@@ -15,15 +15,108 @@
# limitations under the License.
#
-from typing import Tuple
+from typing import Optional, Tuple
import torch
-from vllm.distributed import tensor_model_parallel_all_reduce
-from vllm.model_executor.layers.vocab_parallel_embedding import \
- VocabParallelEmbedding
+from torch import nn
+from torch.nn.parameter import Parameter
+from vllm.distributed import divide, tensor_model_parallel_all_reduce
+from vllm.distributed.parallel_state import get_tp_group
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization.base_config import (
+ QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding)
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+ DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, UnquantizedEmbeddingMethod,
+ VocabParallelEmbedding, pad_vocab_size)
+from vllm.model_executor.utils import set_weight_attrs
+
+from vllm_ascend.distributed.parallel_state import get_lmhead_tp_group
+from vllm_ascend.utils import lmhead_tp_enable
class AscendVocabParallelEmbedding(VocabParallelEmbedding):
+ """
+ Register VocabParallelEmbedding as a custom op for Ascend.
+ AscendVocabParallelEmbedding support different communication parallel groups
+ Added the feature of lmheadTP in pure dp scenario
+ """
+
+ def __init__(self,
+ num_embeddings: int,
+ embedding_dim: int,
+ params_dtype: Optional[torch.dtype] = None,
+ org_num_embeddings: Optional[int] = None,
+ padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = ""):
+ nn.Module.__init__(self)
+
+ if lmhead_tp_enable() and prefix.find("lm_head") != -1:
+ self.comm_group = get_lmhead_tp_group()
+ else:
+ self.comm_group = get_tp_group()
+
+ self.tp_size = self.comm_group.world_size
+ self.tp_rank = self.comm_group.rank_in_group
+
+ self.num_embeddings = num_embeddings
+ self.padding_size = padding_size
+ self.org_vocab_size = org_num_embeddings or num_embeddings
+ num_added_embeddings = num_embeddings - self.org_vocab_size
+ self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size,
+ self.padding_size)
+ self.num_embeddings_padded = pad_vocab_size(
+ self.org_vocab_size_padded + num_added_embeddings,
+ self.padding_size)
+ assert self.org_vocab_size_padded <= self.num_embeddings_padded
+
+ self.shard_indices = self._get_indices(self.num_embeddings_padded,
+ self.org_vocab_size_padded,
+ self.num_embeddings,
+ self.org_vocab_size,
+ self.tp_rank, self.tp_size)
+ self.embedding_dim = embedding_dim
+ quant_method = None
+ if quant_config is not None:
+ quant_method = quant_config.get_quant_method(self, prefix=prefix)
+ if quant_method is None:
+ quant_method = UnquantizedEmbeddingMethod()
+
+ # If we are making an embedding layer, then our quantization linear
+ # method must implement the embedding operation. If we are another
+ # layer type like ParallelLMHead, this is not important.
+ is_embedding_layer = type(self) is VocabParallelEmbedding
+ quant_method_implements_embedding = method_has_implemented_embedding(
+ type(quant_method))
+ if is_embedding_layer and not quant_method_implements_embedding:
+ raise NotImplementedError(
+ f"The class {type(quant_method).__name__} must implement "
+ "the 'embedding' method, see UnquantizedEmbeddingMethod.")
+
+ self.quant_method: QuantizeMethodBase = quant_method
+
+ if params_dtype is None:
+ params_dtype = torch.get_default_dtype()
+ # Divide the weight matrix along the vocaburaly dimension.
+ self.num_added_embeddings = self.num_embeddings - self.org_vocab_size
+ self.num_embeddings_per_partition = divide(self.num_embeddings_padded,
+ self.tp_size)
+ assert (self.shard_indices.num_elements_padded ==
+ self.num_embeddings_per_partition)
+ self.num_org_embeddings_per_partition = (
+ self.shard_indices.org_vocab_end_index -
+ self.shard_indices.org_vocab_start_index)
+ self.num_added_embeddings_per_partition = (
+ self.shard_indices.added_vocab_end_index -
+ self.shard_indices.added_vocab_start_index)
+
+ self.quant_method.create_weights(self,
+ self.embedding_dim,
+ [self.num_embeddings_per_partition],
+ self.embedding_dim,
+ self.num_embeddings_padded,
+ params_dtype=params_dtype,
+ weight_loader=self.weight_loader)
def _get_masked_input_and_mask(
self, input_: torch.Tensor, org_vocab_start_index: int,
@@ -71,3 +164,91 @@ class AscendVocabParallelEmbedding(VocabParallelEmbedding):
# Reduce across all the model parallel GPUs.
output = tensor_model_parallel_all_reduce(output_parallel)
return output
+
+
+class AscendParallelLMHead(ParallelLMHead):
+ """
+ Register ParallelLMHead as a custom op for Ascend."""
+
+ def __init__(self,
+ num_embeddings: int,
+ embedding_dim: int,
+ bias: bool = False,
+ params_dtype: Optional[torch.dtype] = None,
+ org_num_embeddings: Optional[int] = None,
+ padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = ""):
+ AscendVocabParallelEmbedding.__init__(self, num_embeddings,
+ embedding_dim, params_dtype,
+ org_num_embeddings, padding_size,
+ quant_config, prefix)
+
+ self.quant_config = quant_config
+ if bias:
+ self.bias = Parameter(
+ torch.empty(self.num_embeddings_per_partition,
+ dtype=params_dtype))
+ set_weight_attrs(self.bias, {
+ "output_dim": 0,
+ "weight_loader": self.weight_loader,
+ })
+ else:
+ self.register_parameter("bias", None)
+
+
+class AscendLogitsProcessor(LogitsProcessor):
+ """
+ Register LogitsProcessor as a custom op for Ascend.
+ Added the feature of lmheadTP in pure dp scenario
+ """
+
+ def _get_logits(
+ self,
+ hidden_states: torch.Tensor,
+ lm_head: AscendParallelLMHead,
+ embedding_bias: Optional[torch.Tensor] = None,
+ ) -> Optional[torch.Tensor]:
+ if lmhead_tp_enable():
+ return self._get_logits_lmheadtp(hidden_states, lm_head,
+ embedding_bias)
+ else:
+ return self._get_logits_normal(hidden_states, lm_head,
+ embedding_bias)
+
+ def _get_logits_lmheadtp(
+ self,
+ hidden_states: torch.Tensor,
+ lm_head: AscendParallelLMHead,
+ embedding_bias: Optional[torch.Tensor],
+ ) -> Optional[torch.Tensor]:
+ # Gather hidden states from all devices in tensor parallel group
+ gathered_hidden_states = get_lmhead_tp_group().all_gather(
+ hidden_states, dim=0)
+ local_logits = lm_head.quant_method.apply(lm_head,
+ gathered_hidden_states,
+ bias=embedding_bias)
+ # Gather logits for tensor parallel
+ logits = get_lmhead_tp_group().all_to_all(local_logits)
+ # Remove paddings in vocab (if any)
+ if logits is not None:
+ logits = logits[..., :self.org_vocab_size]
+ return logits
+
+ def _get_logits_normal(
+ self,
+ hidden_states: torch.Tensor,
+ lm_head: AscendParallelLMHead,
+ embedding_bias: Optional[torch.Tensor],
+ ) -> Optional[torch.Tensor]:
+ local_logits = lm_head.quant_method.apply(lm_head,
+ hidden_states,
+ bias=embedding_bias)
+ # Gather logits for tensor parallel
+ logits = self._gather_logits(local_logits)
+
+ # Remove paddings in vocab (if any)
+ if logits is not None:
+ logits = logits[..., :self.org_vocab_size]
+
+ return logits
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index a99a491..adab490 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -33,6 +33,7 @@ from torch_npu.npu.streams import Event
from vllm.logger import logger
import vllm_ascend.envs as envs_ascend
+from vllm_ascend.ascend_config import get_ascend_config
if TYPE_CHECKING:
from vllm.config import VllmConfig
@@ -489,6 +490,9 @@ def register_ascend_customop():
AscendMlpRowParallelLinear)
from vllm_ascend.ops.rotary_embedding import (
AscendDeepseekScalingRotaryEmbedding, AscendRotaryEmbedding)
+ from vllm_ascend.ops.vocab_parallel_embedding import (
+ AscendLogitsProcessor, AscendParallelLMHead,
+ AscendVocabParallelEmbedding)
CustomOp.register_oot(_decorated_op_cls=AscendQuickGELU, name="QuickGELU")
CustomOp.register_oot(_decorated_op_cls=AscendSiluAndMul,
name="SiluAndMul")
@@ -497,6 +501,12 @@ def register_ascend_customop():
CustomOp.register_oot(
_decorated_op_cls=AscendDeepseekScalingRotaryEmbedding,
name="DeepseekScalingRotaryEmbedding")
+ CustomOp.register_oot(_decorated_op_cls=AscendVocabParallelEmbedding,
+ name="VocabParallelEmbedding")
+ CustomOp.register_oot(_decorated_op_cls=AscendParallelLMHead,
+ name="ParallelLMHead")
+ CustomOp.register_oot(_decorated_op_cls=AscendLogitsProcessor,
+ name="LogitsProcessor")
if envs_ascend.VLLM_ASCEND_ENABLE_MLP_OPTIMIZE:
CustomOp.register_oot(_decorated_op_cls=AscendMlpColumnParallelLinear,
name="ColumnParallelLinear")
@@ -512,11 +522,6 @@ def register_ascend_customop():
from vllm_ascend.ops.common_fused_moe import AscendFusedMoE
CustomOp.register_oot(_decorated_op_cls=AscendFusedMoE, name="FusedMoE")
- from vllm_ascend.ops.vocab_parallel_embedding import \
- AscendVocabParallelEmbedding
- CustomOp.register_oot(_decorated_op_cls=AscendVocabParallelEmbedding,
- name="VocabParallelEmbedding")
-
# NOTE: Keep this at last to ensure all custom actions are registered
_ASCEND_CUSTOMOP_IS_REIGISTERED = True
@@ -547,3 +552,7 @@ def get_ascend_soc_version():
global _ascend_soc_version
assert _ascend_soc_version is not None
return _ascend_soc_version
+
+
+def lmhead_tp_enable() -> bool:
+ return get_ascend_config().lmhead_tensor_parallel_size is not None
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 6930a1c..15effb7 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -90,7 +90,7 @@ from vllm_ascend.torchair.torchair_attention import AscendTorchairMetadata
from vllm_ascend.torchair.torchair_mla import AscendMLATorchairMetadata
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
ProfileExecuteDuration, is_310p,
- vllm_version_is)
+ lmhead_tp_enable, vllm_version_is)
from vllm_ascend.worker.eagle_proposer_v1 import EagleProposer
from vllm_ascend.worker.mtp_proposer_v1 import MtpProposer
from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch
@@ -1277,6 +1277,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
num_draft_tokens, cu_num_tokens)
logits_indices = spec_decode_metadata.logits_indices
+ if lmhead_tp_enable():
+ max_num_reqs_across_dp = maybe_padded_num_tokens if not with_prefill else self.max_num_reqs
+ logits_indices = nn.functional.pad(
+ logits_indices,
+ (0, max_num_reqs_across_dp - logits_indices.shape[0]))
+
return (attn_metadata, positions, num_scheduled_tokens,
num_input_tokens, num_tokens_across_dp,
maybe_padded_num_tokens, logits_indices, spec_decode_metadata,
@@ -1734,11 +1740,15 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# Sample the next token and get logprobs if needed.
sampling_metadata = self.input_batch.sampling_metadata
if spec_decode_metadata is None:
+ if lmhead_tp_enable() and logits is not None:
+ logits = logits[:self.input_batch.num_reqs]
sampler_output = self.sampler(
logits=logits,
sampling_metadata=sampling_metadata,
)
else:
+ if lmhead_tp_enable() and logits is not None:
+ logits = logits[:len(spec_decode_metadata.logits_indices)]
# When indexing with a tensor (bonus_logits_indices), PyTorch
# creates a new tensor with separate storage from the original
# logits tensor. This means any in-place operations on bonus_logits
@@ -2081,6 +2091,18 @@ class NPUModelRunner(LoRAModelRunnerMixin):
f"Aclgraph runtime mode mismatch at dummy_run. "
f"Expected {_cg_mode}, but got {aclgraph_runtime_mode}.")
+ need_dummy_logits = (not self.in_profile_run
+ and lmhead_tp_enable())
+
+ if need_dummy_logits:
+ max_num_reqs_across_dp = num_tokens if not with_prefill else max_num_reqs
+ dummy_indices = torch.zeros(max_num_reqs_across_dp,
+ dtype=torch.int32)
+
+ def dummy_compute_logits(hidden_states):
+ return self.model.compute_logits(
+ hidden_states[dummy_indices], None)
+
with set_ascend_forward_context(
attn_metadata,
self.vllm_config,
@@ -2097,6 +2119,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
with_prefill, is_torchair_compile, input_ids, positions,
attn_metadata, num_tokens, intermediate_tensors,
inputs_embeds)
+ if need_dummy_logits:
+ dummy_compute_logits(hidden_states)
+
if self.speculative_config and self.speculative_config.method == "deepseek_mtp":
assert isinstance(self.drafter, MtpProposer)
self.drafter.dummy_run(
@@ -2105,7 +2130,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
skip_attn=True,
num_reqs=num_reqs,
num_tokens_across_dp=num_tokens_across_dp)
-
+ if need_dummy_logits:
+ dummy_compute_logits(hidden_states)
return hidden_states
@contextmanager
diff --git a/vllm_ascend/worker/mtp_proposer_v1.py b/vllm_ascend/worker/mtp_proposer_v1.py
index 120b17a..848da93 100644
--- a/vllm_ascend/worker/mtp_proposer_v1.py
+++ b/vllm_ascend/worker/mtp_proposer_v1.py
@@ -19,7 +19,7 @@ from vllm_ascend.ascend_forward_context import set_ascend_forward_context
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
from vllm_ascend.models.deepseek_mtp import CustomDeepSeekMTP
from vllm_ascend.torchair.utils import TorchairCommonAttentionMetadata
-from vllm_ascend.utils import ProfileExecuteDuration
+from vllm_ascend.utils import ProfileExecuteDuration, lmhead_tp_enable
class MtpProposer:
@@ -235,8 +235,20 @@ class MtpProposer:
previous_hidden_states=self.
hidden_states[:num_input_tokens],
kv_caches=self.runner.kv_caches[-1:])
+
+ num_indices = last_token_indices.shape[0]
+ if lmhead_tp_enable():
+ if not self.runner.with_prefill:
+ max_num_reqs_across_dp = num_input_tokens
+ else:
+ max_num_reqs_across_dp = self.vllm_config.scheduler_config.max_num_seqs
+ last_token_indices = nn.functional.pad(
+ last_token_indices, (0, max_num_reqs_across_dp - num_indices))
+
sample_hidden_states = hidden_states[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states, None)
+ if lmhead_tp_enable() and num_indices < logits.shape[0]:
+ logits = logits[:num_indices]
draft_token_ids = logits.argmax(dim=-1)
# [batch_size, 1]