[Feat] Support MLP_TP feature, exclude MOE layer (#4999)

#4257: This PR implements dense FFN tensor parallelism (TP) for the first three (dense) layers of the DeepSeek model. The PR has been refactored so that only a small amount of code is needed to support this feature. It also adds an `is_moe_layer` helper to mlp_tp, so that MLP TP works in models that mix dense MLP and MoE layers, such as DeepSeek or ChatGLM.
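For context, a minimal sketch of how such a dense-vs-MoE layer check can work for DeepSeek-style configs is shown below. This is illustrative only and is not the exact helper added in this PR; the signature is an assumption, while `first_k_dense_replace` and `n_routed_experts` follow the Hugging Face DeepSeek config convention.

```python
# Illustrative sketch; the actual is_moe_layer helper in this PR may differ.
def is_moe_layer(layer_idx: int, hf_config) -> bool:
    """Return True if the layer at `layer_idx` uses routed experts (MoE).

    DeepSeek-style configs keep the first `first_k_dense_replace` layers as
    dense MLPs and route the remaining layers through experts.
    """
    first_k_dense = getattr(hf_config, "first_k_dense_replace", 0)
    has_experts = getattr(hf_config, "n_routed_experts", None) is not None
    return has_experts and layer_idx >= first_k_dense
```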


- vLLM version: v0.12.0
- vLLM main: ad32e3e19c

Signed-off-by: zzhx1 <zzh_201018@outlook.com>
Co-authored-by: 子潜 <ziqian@U-DMKXH32D-2015.local>
Co-authored-by: chenxiao <Jaychou1620@Gmail.com>
Co-authored-by: Jade Zheng <zheng.shoujian@outlook.com>
zzhxxx
2025-12-18 20:06:53 +08:00
committed by GitHub
parent 5a88e3333b, commit a74a1196c5
3 changed files with 37 additions and 23 deletions


@@ -96,25 +96,11 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
backend,
group_name="mc2")
-# Initialize specialized tensor parallel (TP) process groups for fine-grained model parallelism
-# on Ascend hardware. This enables independent TP configurations for three critical components:
-# 1. ** LM Head **:
-#    The final linear layer that maps hidden states to vocabulary logits.
-#    Controlled by `lmhead_tensor_parallel_size`.
-# 2. ** o_proj **:
-#    The output projection in attention blocks (e.g., in Multi-Head Attention).
-#    Controlled by `oproj_tensor_parallel_size`.
-# 3. ** Embedding **:
-#    The token embedding table at the input and/or output of the model.
-#    Controlled by `embedding_tensor_parallel_size`.
-# 4. ** MLP **:
-#    The feed-forward network layers within transformer blocks.
-#    Controlled by `mlp_tensor_parallel_size`.
+# Initialize fine-grained TP process groups on Ascend for four components:
+# 1. LM Head: output logits projection (`lmhead_tensor_parallel_size`)
+# 2. O Proj: attention output projection (`oproj_tensor_parallel_size`)
+# 3. Embedding: The token embedding table at the input of the model (`embedding_tensor_parallel_size`)
+# 4. MLP: feed-forward network in transformer blocks (`mlp_tensor_parallel_size`)
_group_cache = {}
def _create_or_get_group(group_size: int,
@@ -149,9 +135,9 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
embedding_tp_size = get_ascend_config(
).finegrained_tp_config.embedding_tensor_parallel_size
mlp_tp_size = get_ascend_config(
-).finegrained_tp_config.embedding_tensor_parallel_size
+).finegrained_tp_config.mlp_tensor_parallel_size
-global _OTP, _LMTP, _EMBED_TP
+global _OTP, _LMTP, _EMBED_TP, _MLP_TP
if otp_size > 0:
_OTP = _create_or_get_group(otp_size, "otp")
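
For readers skimming the diff, the sketch below reconstructs how the cached group helper and the new MLP branch likely fit together. Only the call sites, the config read, and the `_MLP_TP` global are visible in the diff; the helper body, the size-keyed cache, the `"mlptp"` group name, and `new_tp_group()` are assumptions for illustration.

```python
# Sketch only: reconstruction of the caching pattern around _create_or_get_group.
# new_tp_group() is a hypothetical stand-in for the real process-group factory.
from vllm_ascend.ascend_config import get_ascend_config

_group_cache: dict[int, object] = {}

def _create_or_get_group(group_size: int, group_name: str):
    # Assumed behavior: reuse an existing group of the same size so that, e.g.,
    # o_proj TP and MLP TP with equal sizes share one communicator.
    if group_size not in _group_cache:
        _group_cache[group_size] = new_tp_group(group_size, group_name)
    return _group_cache[group_size]

mlp_tp_size = get_ascend_config().finegrained_tp_config.mlp_tensor_parallel_size
if mlp_tp_size > 0:
    _MLP_TP = _create_or_get_group(mlp_tp_size, "mlptp")  # group name assumed
```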