### What this PR does / why we need it?
This PR introduces LM head tensor model parallelism to reduce memory
consumption and improve TPOT. It supports both eager mode and graph mode.
On a DeepSeek R1 W8A8 PD-disaggregated decode instance running pure DP
with lmhead_tensor_parallel_size = 8, we measured a 1 ms TPOT improvement
and 1.48 GB of NPU memory saved per rank.
Performance data:
<img width="1444" height="438" alt="image" src="https://github.com/user-attachments/assets/3c5ef0d3-a7c7-46fd-9797-4de728eb0cb0" />
### Does this PR introduce _any_ user-facing change?
This PR introduces one new config in `additional_config`.
| Name | Effect | Required | Type | Constraints |
| :--- | :--- | :--- | :--- | :--- |
| `lmhead_tensor_parallel_size` | Splits the lm_head matrix along the column dimension (`vocab_size`) into `lmhead_tensor_parallel_size` pieces | No | int | Defaults to `None`. Setting any value enables the feature; `vocab_size` must be divisible by it. |

Example:
`--additional_config={"lmhead_tensor_parallel_size": 8}`
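For intuition, here is a minimal sketch of the split this config implies
(the shapes below are illustrative assumptions, not values from this PR):

```python
import torch

vocab_size, hidden_size = 129280, 7168  # illustrative DeepSeek-R1-like shapes
lmhead_tp = 8
# vocab_size must be divisible by lmhead_tensor_parallel_size
assert vocab_size % lmhead_tp == 0

# The full lm_head weight is (vocab_size, hidden_size); each rank keeps one
# vocab shard, cutting per-rank lm_head memory by a factor of lmhead_tp.
# device="meta" tracks shapes without allocating real memory.
lm_head = torch.empty(vocab_size, hidden_size, device="meta")
shards = torch.chunk(lm_head, lmhead_tp, dim=0)
print(shards[0].shape)  # torch.Size([16160, 7168])
```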
### How was this patch tested?
- vLLM version: v0.10.1.1
- vLLM main: de533ab2a1
---------
Signed-off-by: zzhx1 <zzh_201018@outlook.com>
Co-authored-by: zhangzihang <zzh_201018@outlook.com>
from typing import Optional

import torch
from vllm.config import ParallelConfig
from vllm.distributed.parallel_state import (GroupCoordinator, get_world_group,
                                             init_model_parallel_group)

import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config

# Currently, the mc2 op needs its own group coordinator.
_MC2: Optional[GroupCoordinator] = None

_MLP_TP: Optional[GroupCoordinator] = None

_LMTP: Optional[GroupCoordinator] = None


def get_mc2_group() -> GroupCoordinator:
    assert _MC2 is not None, ("mc2 group is not initialized")
    return _MC2


def get_lmhead_tp_group() -> GroupCoordinator:
    assert _LMTP is not None, (
        "lm head tensor parallel group is not initialized")
    return _LMTP


def get_mlp_tp_group() -> GroupCoordinator:
    assert _MLP_TP is not None, ("mlp group is not initialized")
    return _MLP_TP


def model_parallel_initialized():
    return (_MC2 is not None)


def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
    if model_parallel_initialized():
        return
    assert torch.distributed.is_initialized()
    world_size = torch.distributed.get_world_size()
    backend = torch.distributed.get_backend(get_world_group().device_group)

    # The layout of all ranks: ExternalDP * EP
    # ExternalDP is the data parallel group that is not part of the model,
    # every dp rank can generate independently (in verl integration).
    all_ranks = torch.arange(world_size).reshape(
        -1, parallel_config.data_parallel_size *
        parallel_config.tensor_parallel_size)
    global _MC2
    group_ranks = all_ranks.unbind(0)
    group_ranks = [x.tolist() for x in group_ranks]

    _MC2 = init_model_parallel_group(group_ranks,
                                     get_world_group().local_rank,
                                     backend,
                                     group_name="mc2")

    if envs_ascend.VLLM_ASCEND_ENABLE_MLP_OPTIMIZE:
        global _MLP_TP
        assert _MLP_TP is None, (
            "mlp tensor model parallel group is already initialized")

        mlp_tp = parallel_config.data_parallel_size

        all_ranks_mlp_head = torch.arange(world_size).reshape(
            -1, mlp_tp, parallel_config.pipeline_parallel_size, 1)  # noqa
        group_ranks = all_ranks_mlp_head.view(-1, mlp_tp).unbind(0)
        group_ranks = [x.tolist() for x in group_ranks]

        # message queue broadcaster is only used in tensor model parallel group
        _MLP_TP = init_model_parallel_group(group_ranks,
                                            get_world_group().local_rank,
                                            backend,
                                            group_name="mlp_tp")

    lmhead_tensor_parallel_size = get_ascend_config(
    ).lmhead_tensor_parallel_size
    if lmhead_tensor_parallel_size is not None:
        group_ranks = []
        global _LMTP
        num_lmhead_tensor_parallel_groups: int = (world_size //
                                                  lmhead_tensor_parallel_size)
        for i in range(num_lmhead_tensor_parallel_groups):
            ranks = list(
                range(i * lmhead_tensor_parallel_size,
                      (i + 1) * lmhead_tensor_parallel_size))
            group_ranks.append(ranks)
        _LMTP = init_model_parallel_group(group_ranks,
                                          get_world_group().local_rank,
                                          backend,
                                          group_name="lmheadtp")
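        # Illustrative layout (an assumption, not from this PR): with
        # world_size = 16 and lmhead_tensor_parallel_size = 8, the loop
        # above builds two groups, ranks [0..7] and [8..15]; each rank
        # only ever gathers lm_head outputs within its own group.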


def get_mlp_tensor_model_parallel_world_size():
    """Return world size for the mlp tensor model parallel group."""
    return get_mlp_tp_group().world_size


def get_mlp_tensor_model_parallel_rank():
    """Return the rank within the mlp tensor model parallel group."""
    return get_mlp_tp_group().rank_in_group


def destroy_ascend_model_parallel():
    global _MC2
    if _MC2:
        _MC2.destroy()
    _MC2 = None

    global _MLP_TP
    if _MLP_TP:
        _MLP_TP.destroy()
    _MLP_TP = None

    global _LMTP
    if _LMTP:
        _LMTP.destroy()
    _LMTP = None
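For context, a hedged sketch of how the new group might be consumed on the
decode path. The helper name `compute_full_logits`, the module path in the
import, and the all-gather flow are illustrative assumptions; only
`get_lmhead_tp_group` comes from the code above:

```python
import torch

# Assumed module path for the file above; adjust to where it actually lives.
from vllm_ascend.distributed.parallel_state import get_lmhead_tp_group


def compute_full_logits(hidden: torch.Tensor,
                        lm_head_shard: torch.Tensor) -> torch.Tensor:
    """Hypothetical decode-side use: each rank holds one vocab shard."""
    # hidden: (num_tokens, hidden_size)
    # lm_head_shard: (vocab_size // lmhead_tp, hidden_size)
    partial_logits = hidden @ lm_head_shard.t()
    # GroupCoordinator exposes collective helpers; gathering along the last
    # dim reassembles the full (num_tokens, vocab_size) logits per group.
    return get_lmhead_tp_group().all_gather(partial_logits, dim=-1)
```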