[Feat]: Add custom lmhead tensor model parallel (#2309)

### What this PR does / why we need it?
This PR introduces LMhead tensor model parallel to achieve decreasing of
memory consumption, and TPOT performance improvement. It support both
eager mode and graph mode.

In deepseek r1 w8a8 PD disagregated Decode instance, using pure DP, with
lmhead_tensor_parallel_size = 8, we have 1 ms TPOT optimization, saved
1.48 GB NPU memory per RANK.

performance data:
<img width="1444" height="438" alt="image"
src="https://github.com/user-attachments/assets/3c5ef0d3-a7c7-46fd-9797-4de728eb0cb0"
/>

### Does this PR introduce _any_ user-facing change?
This PR introduces one new config in `additional_config`.
| Name | Effect | Required | Type | Constraints |
| :---------------------------- |
:--------------------------------------- | :------- | :--- |
:----------------- |
| lmhead_tensor_parallel_size | Split the lm_head matrix along the
column dimension (vocab_size) into lmhead_tensor_parallel_size pieces |
No | int | default value is None, once this value is set, the feature
will be enabled, vocab_size must be divisible by this value. |

example

`--additional_config={"lmhead_tensor_parallel_size": 8}`

### How was this patch tested?


- vLLM version: v0.10.1.1
- vLLM main:
de533ab2a1

---------

Signed-off-by: zzhx1 <zzh_201018@outlook.com>
Co-authored-by: zhangzihang <zzh_201018@outlook.com>
This commit is contained in:
lidenghui1110
2025-08-29 11:41:21 +08:00
committed by GitHub
parent e7ad4a64f4
commit 600b08f754
14 changed files with 458 additions and 22 deletions

View File

@@ -26,7 +26,7 @@ from vllm_ascend.models.deepseek_v2 import (
CustomDeepseekV2MLP, CustomDeepseekV2MoE,
CustomDeepseekV2RowParallelLinear,
CustomDeepseekV2RowParallelLinearReplaceAllreduce,
CustomDeepseekV2SiluAndMul)
CustomDeepseekV2SiluAndMul, LogitsProcessor, ParallelLMHead)
@pytest.fixture
@@ -266,3 +266,30 @@ def test_custom_deepseek_v2_mla_attention(mock_rms_norm, mock_distributed,
kv_lora_rank=16,
prefix="layers.1.self_attn")
assert hasattr(attn, "q_proj")
def test_deepseek_v2_lmhead(mock_distributed, vllm_config):
# 创建一个简单的配置对象
class SimpleConfig:
def __init__(self):
self.vocab_size = 10000
self.hidden_size = 128
config = SimpleConfig()
# 直接创建lmhead和logits_processor
lmhead = ParallelLMHead(config.vocab_size, config.hidden_size)
logits_processor = LogitsProcessor(config.vocab_size)
# 创建模拟输出
mock_output = torch.randn(2, 4, config.hidden_size)
mock_logits = torch.randn(2, 4, config.vocab_size)
# 直接测试logits_processor
with patch.object(lmhead.quant_method, "apply", return_value=mock_logits):
with patch.object(logits_processor,
"_gather_logits",
return_value=mock_logits):
logits = logits_processor(lmhead, mock_output)
assert logits.shape == (2, 4, config.vocab_size)