Files
xc-llm-ascend/tests/ut/distributed/test_parallel_state.py
lidenghui1110 600b08f754 [Feat]: Add custom lmhead tensor model parallel (#2309)
### What this PR does / why we need it?
This PR introduces LMhead tensor model parallel to achieve decreasing of
memory consumption, and TPOT performance improvement. It support both
eager mode and graph mode.

In deepseek r1 w8a8 PD disagregated Decode instance, using pure DP, with
lmhead_tensor_parallel_size = 8, we have 1 ms TPOT optimization, saved
1.48 GB NPU memory per RANK.

performance data:
<img width="1444" height="438" alt="image"
src="https://github.com/user-attachments/assets/3c5ef0d3-a7c7-46fd-9797-4de728eb0cb0"
/>

### Does this PR introduce _any_ user-facing change?
This PR introduces one new config in `additional_config`.
| Name | Effect | Required | Type | Constraints |
| :---------------------------- |
:--------------------------------------- | :------- | :--- |
:----------------- |
| lmhead_tensor_parallel_size | Split the lm_head matrix along the
column dimension (vocab_size) into lmhead_tensor_parallel_size pieces |
No | int | default value is None, once this value is set, the feature
will be enabled, vocab_size must be divisible by this value. |

example

`--additional_config={"lmhead_tensor_parallel_size": 8}`

### How was this patch tested?


- vLLM version: v0.10.1.1
- vLLM main:
de533ab2a1

---------

Signed-off-by: zzhx1 <zzh_201018@outlook.com>
Co-authored-by: zhangzihang <zzh_201018@outlook.com>
2025-08-29 11:41:21 +08:00

45 lines
1.7 KiB
Python

from unittest.mock import MagicMock, patch
import pytest
from vllm.config import ParallelConfig
from vllm_ascend.distributed.parallel_state import (
_LMTP, _MC2, destroy_ascend_model_parallel, get_lmhead_tp_group,
get_mc2_group, init_ascend_model_parallel)
@pytest.fixture
def parallel_config():
return ParallelConfig(data_parallel_size=2,
tensor_parallel_size=2,
pipeline_parallel_size=2)
@pytest.fixture
def mock_distributed():
with patch('torch.distributed.is_initialized', return_value=True), \
patch('torch.distributed.get_world_size', return_value=8), \
patch('torch.distributed.get_backend', return_value='nccl'), \
patch('vllm_ascend.distributed.parallel_state.get_world_group') as mock_group:
mock_group.return_value.local_rank = 0
mock_group.return_value.device_group = MagicMock()
yield
def test_init_ascend_model_parallel(mock_distributed, parallel_config):
mock_ascend_config = MagicMock()
mock_ascend_config.lmhead_tensor_parallel_size = 2
with patch('vllm_ascend.distributed.parallel_state.model_parallel_initialized', return_value=False), \
patch('vllm_ascend.distributed.parallel_state.init_model_parallel_group'), \
patch('vllm_ascend.distributed.parallel_state.get_ascend_config', return_value=mock_ascend_config):
init_ascend_model_parallel(parallel_config)
mc2_group = get_mc2_group()
assert mc2_group is not None
lmheadtp_group = get_lmhead_tp_group()
assert lmheadtp_group is not None
destroy_ascend_model_parallel()
assert _MC2 is None
assert _LMTP is None