[Feat] Add custom Embedding tensor model parallel (#2616)

Similar to #2309, this PR introduces Embedding tensor model parallelism to
reduce memory consumption. It supports both eager mode and
graph mode.

This PR also refactors the module tensor parallel configurations introduced in
#2309, #2167 and #2120, merging all of them into `finegrained_tp_config` in
`additional_config`, including:
`lmhead_tensor_parallel_size`
`oproj_tensor_parallel_size`
`embedding_tensor_parallel_size`
`mlp_tensor_parallel_size`

- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: zzhx1 <zzh_201018@outlook.com>
Signed-off-by: zzhxx <zhangzihang23@mails.ucas.ac.cn>
Co-authored-by: zzhx1 <zzh_201018@outlook.com>
Co-authored-by: chenxiao <Jaychou1620@Gmail.com>
Co-authored-by: zzhxx <zhangzihang23@mails.ucas.ac.cn>
Co-authored-by: Jade Zheng <zheng.shoujian@outlook.com>
This commit is contained in:
lidenghui1110
2025-12-12 14:41:20 +08:00
committed by GitHub
parent b8a317caac
commit d65fb194d9
9 changed files with 301 additions and 162 deletions

View File

@@ -7,69 +7,27 @@ from vllm.distributed.parallel_state import (GroupCoordinator, get_dp_group,
get_world_group,
init_model_parallel_group)
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.utils import (enable_sp, flashcomm2_enable,
flashcomm2_o_shared_enabled)
# Currently, mc2 op need their own group coordinator.
# MC2 ops currently need a group coordinator of their own.
# Every global below stays None until its group has been created.
_MC2: Optional[GroupCoordinator] = None
# Module-specific tensor parallel groups.
# MLP tensor parallel group.
_MLP_TP: Optional[GroupCoordinator] = None
# Output (o_proj) tensor parallel group.
_OTP: Optional[GroupCoordinator] = None
# LM head tensor parallel group.
_LMTP: Optional[GroupCoordinator] = None
# Distributed prefill tensor parallel group.
_P_TP: Optional[GroupCoordinator] = None
# Embedding tensor parallel group.
_EMBED_TP: Optional[GroupCoordinator] = None
# Flashcomm2-specific groups (output tensor parallel / output data parallel).
_FLASHCOMM2_OTP: Optional[GroupCoordinator] = None
_FLASHCOMM2_ODP: Optional[GroupCoordinator] = None
# Group used to share weights across ranks.
_SHARED_WEIGHT: Optional[GroupCoordinator] = None
def get_mc2_group() -> GroupCoordinator:
    """Return the MC2 group coordinator.

    Raises:
        AssertionError: if the MC2 group has not been initialized yet.
    """
    group = _MC2
    assert group is not None, ("mc2 group is not initialized")
    return group
def get_otp_group() -> GroupCoordinator:
    """Return the output tensor parallel group coordinator.

    Raises:
        AssertionError: if the group has not been initialized yet.
    """
    group = _OTP
    assert group is not None, (
        "output tensor parallel group is not initialized")
    return group
def get_lmhead_tp_group() -> GroupCoordinator:
    """Return the LM head tensor parallel group coordinator.

    Raises:
        AssertionError: if the group has not been initialized yet.
    """
    group = _LMTP
    assert group is not None, (
        "lm head tensor parallel group is not initialized")
    return group
def get_flashcomm2_otp_group() -> Optional[GroupCoordinator]:
    """Return the flashcomm2 output tensor parallel group, if any.

    Unlike the other group getters in this module, this one does not assert
    that the group exists: the module-level value is returned as-is, so the
    result is ``None`` when the group has not been created. Callers must
    handle ``None``.
    """
    return _FLASHCOMM2_OTP
def get_flashcomm2_odp_group() -> GroupCoordinator:
    """Return the flashcomm2 output data parallel group coordinator.

    Raises:
        AssertionError: if the group has not been initialized yet.
    """
    group = _FLASHCOMM2_ODP
    assert group is not None, (
        "output data parallel group for flashcomm2 is not initialized")
    return group
def get_shared_weight_group() -> GroupCoordinator:
    """Return the shared-weight group coordinator.

    Raises:
        AssertionError: if the group has not been initialized yet.
    """
    group = _SHARED_WEIGHT
    assert group is not None, (
        "output shared weight parallel group for flashcomm2 is not initialized"
    )
    return group
def get_mlp_tp_group() -> GroupCoordinator:
    """Return the MLP tensor parallel group coordinator.

    Raises:
        AssertionError: if the group has not been initialized yet.
    """
    group = _MLP_TP
    assert group is not None, ("mlp group is not initialized")
    return group
def get_p_tp_group() -> GroupCoordinator:
    """Return the distributed prefill tensor parallel group coordinator.

    Raises:
        AssertionError: if the group has not been initialized yet.
    """
    group = _P_TP
    assert group is not None, (
        "distributed prefill tensor parallel group is not initialized")
    return group
def model_parallel_initialized():
    """Report whether the MC2 group has been created."""
    mc2_ready = _MC2 is not None
    return mc2_ready
# Distributed prefill tensor parallel group; None until it is initialized.
_P_TP: Optional[GroupCoordinator] = None
def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
@@ -79,14 +37,16 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
world_size = torch.distributed.get_world_size()
backend = torch.distributed.get_backend(get_world_group().device_group)
vllm_config = get_current_vllm_config()
global_tp_size = parallel_config.tensor_parallel_size
global_dp_size = parallel_config.data_parallel_size
global_pp_size = parallel_config.pipeline_parallel_size
# The layout of all ranks: ExternalDP * EP
# ExternalDP is the data parallel group that is not part of the model,
# every dp rank can generate independently (in verl integration).
all_ranks = torch.arange(world_size).reshape(
-1, parallel_config.data_parallel_size *
parallel_config.prefill_context_parallel_size *
parallel_config.tensor_parallel_size)
-1, global_dp_size * parallel_config.prefill_context_parallel_size *
global_tp_size)
pd_tp_ratio = get_ascend_config().pd_tp_ratio
pd_head_ratio = get_ascend_config().pd_head_ratio
@@ -98,13 +58,13 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
if pd_head_ratio > 1 and get_current_vllm_config(
).kv_transfer_config.is_kv_producer:
num_head_replica = get_ascend_config().num_head_replica
remote_tp_size = parallel_config.tensor_parallel_size // pd_tp_ratio
remote_tp_size = global_tp_size // pd_tp_ratio
if num_head_replica <= 1:
group_ranks = all_ranks.view(
-1, prefill_tensor_model_parallel_size).unbind(0)
else:
group_ranks = all_ranks.clone().view(
parallel_config.data_parallel_size, -1,
global_dp_size, -1,
num_head_replica) # [DP_size, num_head, num_head_replica]
group_ranks = group_ranks.permute(0, 2, 1)
group_ranks = group_ranks.reshape(
@@ -112,8 +72,7 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
group_ranks.size(-1)) # [DP_size * num_head_replica, num_head]
alltoall_group_size = group_ranks.size(-1) // remote_tp_size
group_ranks = group_ranks.unsqueeze(-1).view(
parallel_config.data_parallel_size, num_head_replica, -1,
alltoall_group_size
global_dp_size, num_head_replica, -1, alltoall_group_size
) # [DP_size, num_head_replica, num_alltoall_group, alltoall_group_size]
group_ranks = group_ranks.reshape(-1,
alltoall_group_size).unbind(0)
@@ -135,54 +94,72 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
get_world_group().local_rank,
backend,
group_name="mc2")
if envs_ascend.VLLM_ASCEND_ENABLE_MLP_OPTIMIZE:
global _MLP_TP
assert _MLP_TP is None, (
"mlp tensor model parallel group is already initialized")
mlp_tp = parallel_config.data_parallel_size
# Initialize specialized tensor parallel (TP) process groups for fine-grained model parallelism
# on Ascend hardware. This enables independent TP configurations for four critical components:
all_ranks_mlp_head = torch.arange(world_size).reshape(
-1, mlp_tp, parallel_config.pipeline_parallel_size, 1) # noqa
group_ranks = all_ranks_mlp_head.view(-1, mlp_tp).unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
# 1. ** LM Head **:
# The final linear layer that maps hidden states to vocabulary logits.
# Controlled by `lmhead_tensor_parallel_size`.
# message queue broadcaster is only used in tensor model parallel group
_MLP_TP = init_model_parallel_group(group_ranks,
get_world_group().local_rank,
backend,
group_name="mlp_tp")
# 2. ** o_proj **:
# The output projection in attention blocks (e.g., in Multi-Head Attention).
# Controlled by `oproj_tensor_parallel_size`.
# If oproj tensor parallel size is set, we will create a group for it.
otp_size = get_ascend_config().oproj_tensor_parallel_size
if otp_size is not None:
group_ranks = []
global _OTP
num_oproj_tensor_parallel_groups: int = (world_size // otp_size)
for i in range(num_oproj_tensor_parallel_groups):
ranks = list(range(i * otp_size, (i + 1) * otp_size))
group_ranks.append(ranks)
_OTP = init_model_parallel_group(group_ranks,
get_world_group().local_rank,
backend,
group_name="otp")
# 3. ** Embedding **:
# The token embedding table at the input and/or output of the model.
# Controlled by `embedding_tensor_parallel_size`.
lmhead_tensor_parallel_size = get_ascend_config(
).lmhead_tensor_parallel_size
if lmhead_tensor_parallel_size is not None:
group_ranks = []
global _LMTP
num_lmhead_tensor_parallel_groups: int = (world_size //
lmhead_tensor_parallel_size)
for i in range(num_lmhead_tensor_parallel_groups):
ranks = list(
range(i * lmhead_tensor_parallel_size,
(i + 1) * lmhead_tensor_parallel_size))
group_ranks.append(ranks)
_LMTP = init_model_parallel_group(group_ranks,
get_world_group().local_rank,
backend,
group_name="lmheadtp")
# 4. ** MLP **:
# The feed-forward network layers within transformer blocks.
# Controlled by `mlp_tensor_parallel_size`.
_group_cache = {}
def _create_or_get_group(group_size: int,
                         group_name: str) -> GroupCoordinator:
    """Build, or fetch from the cache, the process group for ``group_size``.

    The world ranks are viewed as a (pp, dp, tp) grid. Within each pipeline
    stage and for each TP index, every run of ``group_size`` consecutive DP
    rows forms one group.

    The cache is keyed by ``group_size`` only, so a second request with the
    same size returns the group created by the first request, whatever
    ``group_name`` is passed.

    Args:
        group_size: Number of ranks per group; ``None`` means "no group".
        group_name: Name given to a newly created group.

    Returns:
        The cached or newly created group, or ``None`` if ``group_size`` is
        ``None``.
    """
    if group_size is None:
        return None
    if group_size in _group_cache:
        return _group_cache[group_size]

    grid = torch.arange(world_size).reshape(global_pp_size, global_dp_size,
                                            global_tp_size)
    chunks_per_stage = global_dp_size // group_size
    # One group per (pipeline stage, DP chunk, TP column), in that order.
    rank_lists = [
        grid[stage][chunk * group_size:(chunk + 1) * group_size,
                    column].tolist()
        for stage in range(global_pp_size)
        for chunk in range(chunks_per_stage)
        for column in range(global_tp_size)
    ]
    _group_cache[group_size] = init_model_parallel_group(
        rank_lists,
        get_world_group().local_rank,
        backend,
        group_name=group_name)
    return _group_cache[group_size]
# Read each module's tensor parallel size from `finegrained_tp_config`.
otp_size = get_ascend_config(
).finegrained_tp_config.oproj_tensor_parallel_size
lmhead_tp_size = get_ascend_config(
).finegrained_tp_config.lmhead_tensor_parallel_size
embedding_tp_size = get_ascend_config(
).finegrained_tp_config.embedding_tensor_parallel_size
# The MLP group follows its own setting, not the embedding one.
mlp_tp_size = get_ascend_config(
).finegrained_tp_config.mlp_tensor_parallel_size

# `_MLP_TP` must be declared global as well; otherwise the assignment below
# only binds a local name and `get_mlp_tp_group()` never sees the group.
global _OTP, _LMTP, _EMBED_TP, _MLP_TP

# A group is created only for modules whose configured size is positive.
if otp_size > 0:
    _OTP = _create_or_get_group(otp_size, "otp")
if lmhead_tp_size > 0:
    _LMTP = _create_or_get_group(lmhead_tp_size, "lmheadtp")
if embedding_tp_size > 0:
    _EMBED_TP = _create_or_get_group(embedding_tp_size, "emtp")
if mlp_tp_size > 0:
    _MLP_TP = _create_or_get_group(mlp_tp_size, "mlptp")
def _create_shared_weight_group(group_name: str) -> GroupCoordinator:
#This communication domain is used for asynchronous broadcasting, so we will create a new communication group to avoid interference
@@ -265,14 +242,58 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
_SHARED_WEIGHT = _create_shared_weight_group("flashcomm2_o_shared")
def get_mlp_tensor_model_parallel_world_size():
    """Return the world size of the MLP tensor parallel group."""
    mlp_group = get_mlp_tp_group()
    return mlp_group.world_size
def model_parallel_initialized():
    """Report whether the MC2 group has been created."""
    mc2_ready = _MC2 is not None
    return mc2_ready
def get_mlp_tensor_model_parallel_rank():
    """Return this process's rank within the MLP tensor parallel group."""
    return get_mlp_tp_group().rank_in_group
def get_mc2_group() -> GroupCoordinator:
    """Return the MC2 group coordinator.

    Raises:
        AssertionError: if the MC2 group has not been initialized yet.
    """
    group = _MC2
    assert group is not None, ("mc2 group is not initialized")
    return group
def get_mlp_tp_group() -> GroupCoordinator:
    """Return the MLP tensor parallel group coordinator.

    Raises:
        AssertionError: if the group has not been initialized yet.
    """
    group = _MLP_TP
    assert group is not None, ("mlp group is not initialized")
    return group
def get_otp_group() -> GroupCoordinator:
    """Return the output tensor parallel group coordinator.

    Raises:
        AssertionError: if the group has not been initialized yet.
    """
    group = _OTP
    assert group is not None, (
        "output tensor parallel group is not initialized")
    return group
def get_lmhead_tp_group() -> GroupCoordinator:
    """Return the LM head tensor parallel group coordinator.

    Raises:
        AssertionError: if the group has not been initialized yet.
    """
    group = _LMTP
    assert group is not None, (
        "lm head tensor parallel group is not initialized")
    return group
def get_embed_tp_group() -> GroupCoordinator:
    """Return the embedding tensor parallel group coordinator.

    Raises:
        AssertionError: if the group has not been initialized yet.
    """
    group = _EMBED_TP
    assert group is not None, ("emtp group is not initialized")
    return group
def get_flashcomm2_otp_group() -> Optional[GroupCoordinator]:
    """Return the flashcomm2 output tensor parallel group, if any.

    Unlike the other group getters in this module, this one does not assert
    that the group exists: the module-level value is returned as-is, so the
    result is ``None`` when the group has not been created. Callers must
    handle ``None``.
    """
    return _FLASHCOMM2_OTP
def get_flashcomm2_odp_group() -> GroupCoordinator:
    """Return the flashcomm2 output data parallel group coordinator.

    Raises:
        AssertionError: if the group has not been initialized yet.
    """
    group = _FLASHCOMM2_ODP
    assert group is not None, (
        "output data parallel group for flashcomm2 is not initialized")
    return group
def get_shared_weight_group() -> GroupCoordinator:
    """Return the shared-weight group coordinator.

    Raises:
        AssertionError: if the group has not been initialized yet.
    """
    group = _SHARED_WEIGHT
    assert group is not None, (
        "output shared weight parallel group for flashcomm2 is not initialized"
    )
    return group
def get_p_tp_group() -> GroupCoordinator:
    """Return the distributed prefill tensor parallel group coordinator.

    Raises:
        AssertionError: if the group has not been initialized yet.
    """
    group = _P_TP
    assert group is not None, (
        "distributed prefill tensor parallel group is not initialized")
    return group
def destroy_ascend_model_parallel():
@@ -291,6 +312,11 @@ def destroy_ascend_model_parallel():
_LMTP.destroy()
_LMTP = None
global _EMBED_TP
if _EMBED_TP:
_EMBED_TP.destroy()
_EMBED_TP = None
global _OTP
if _OTP:
_OTP.destroy()