[Feat] Support MLP_TP feature, exclude MOE layer (#4999)

#4257 This PR implements dense FFN tensor parallelism (TP) for the first three dense layers of the DeepSeek model. The PR has been refactored so that only a small amount of code is needed to support the feature.
It adds an `is_moe_layer` helper to the MLP TP path, so MLP TP can be used in models that mix dense MLP and MoE layers, such as DeepSeek or ChatGLM.
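
For illustration, here is a minimal sketch of the layer-split rule this feature relies on, assuming a DeepSeek-V3-style config with `first_k_dense_replace = 3` and `moe_layer_freq = 1` (the constants and the helper name are illustrative only; the actual check is the `is_moe_layer` helper added in the last diff below):

```python
import re

# Assumed DeepSeek-V3-style values, for illustration only.
FIRST_K_DENSE_REPLACE = 3  # layers 0-2 keep a dense MLP and get MLP TP
MOE_LAYER_FREQ = 1         # every layer from index 3 onward is MoE

def layer_uses_moe(prefix: str) -> bool:
    """Mirror of the is_moe_layer rule: pull the layer index out of the
    module prefix and compare it against the dense/MoE boundary."""
    match = re.search(r'layers\.(\d+)\.', prefix)
    if match is None:
        return False
    idx = int(match.group(1))
    return idx >= FIRST_K_DENSE_REPLACE and idx % MOE_LAYER_FREQ == 0

print(layer_uses_moe("model.layers.1.mlp.gate_up_proj"))   # False -> MLP TP applies
print(layer_uses_moe("model.layers.10.mlp.gate_up_proj"))  # True  -> excluded from MLP TP
```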


- vLLM version: v0.12.0
- vLLM main: ad32e3e19c

Signed-off-by: zzhx1 <zzh_201018@outlook.com>
Co-authored-by: 子潜 <ziqian@U-DMKXH32D-2015.local>
Co-authored-by: chenxiao <Jaychou1620@Gmail.com>
Co-authored-by: Jade Zheng <zheng.shoujian@outlook.com>
Author: zzhxxx
Date: 2025-12-18 20:06:53 +08:00
Committed by: GitHub
Parent: 5a88e3333b
Commit: a74a1196c5
3 changed files with 37 additions and 23 deletions


@@ -41,6 +41,7 @@ def test_init_ascend_model_parallel(mock_distributed, parallel_config):
     mock_ascend_config.finegrained_tp_config.lmhead_tensor_parallel_size = 2
     mock_ascend_config.finegrained_tp_config.oproj_tensor_parallel_size = 2
     mock_ascend_config.finegrained_tp_config.embedding_tensor_parallel_size = 2
+    mock_ascend_config.finegrained_tp_config.mlp_tensor_parallel_size = 2
     mock_ascend_config.flashcomm2_oproj_tensor_parallel_size = 2
     mock_ascend_config.pd_tp_ratio = 2
     mock_ascend_config.num_head_replica = 0
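
The new `mlp_tensor_parallel_size` entry sits beside the existing lmhead/oproj/embedding knobs in `finegrained_tp_config`. Assuming it is exposed to users through vLLM's `additional_config` in the same way (an assumption; this diff only shows the internal attribute), enabling MLP TP of 2 might look like:

```python
# Hypothetical usage sketch; the additional_config key layout is an assumption,
# not something this diff confirms.
from vllm import LLM

llm = LLM(
    model="deepseek-ai/DeepSeek-V3",
    tensor_parallel_size=16,
    additional_config={
        "finegrained_tp_config": {
            "mlp_tensor_parallel_size": 2,  # TP degree used only for dense-layer MLPs
        },
    },
)
```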


@@ -96,25 +96,11 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
         backend,
         group_name="mc2")
-    # Initialize specialized tensor parallel (TP) process groups for fine-grained model parallelism
-    # on Ascend hardware. This enables independent TP configurations for three critical components:
-    # 1. ** LM Head **:
-    #    The final linear layer that maps hidden states to vocabulary logits.
-    #    Controlled by `lmhead_tensor_parallel_size`.
-    # 2. ** o_proj **:
-    #    The output projection in attention blocks (e.g., in Multi-Head Attention).
-    #    Controlled by `oproj_tensor_parallel_size`.
-    # 3. ** Embedding **:
-    #    The token embedding table at the input and/or output of the model.
-    #    Controlled by `embedding_tensor_parallel_size`.
-    # 4. ** MLP **:
-    #    The feed-forward network layers within transformer blocks.
-    #    Controlled by `mlp_tensor_parallel_size`.
+    # Initialize fine-grained TP process groups on Ascend for four components:
+    # 1. LM Head: output logits projection (`lmhead_tensor_parallel_size`)
+    # 2. O Proj: attention output projection (`oproj_tensor_parallel_size`)
+    # 3. Embedding: The token embedding table at the input of the model (`embedding_tensor_parallel_size`)
+    # 4. MLP: feed-forward network in transformer blocks (`mlp_tensor_parallel_size`)
     _group_cache = {}
     def _create_or_get_group(group_size: int,
@@ -149,9 +135,9 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
     embedding_tp_size = get_ascend_config(
     ).finegrained_tp_config.embedding_tensor_parallel_size
     mlp_tp_size = get_ascend_config(
-    ).finegrained_tp_config.embedding_tensor_parallel_size
-    global _OTP, _LMTP, _EMBED_TP
+    ).finegrained_tp_config.mlp_tensor_parallel_size
+    global _OTP, _LMTP, _EMBED_TP, _MLP_TP
     if otp_size > 0:
         _OTP = _create_or_get_group(otp_size, "otp")
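
The hunks above reference `_group_cache` and `_create_or_get_group` without showing their bodies. As a rough illustration of the caching pattern (a hypothetical sketch built on plain `torch.distributed`, not the actual vllm-ascend implementation), creating and memoizing a subgroup of a given size for the current rank could look like:

```python
import torch.distributed as dist

_group_cache: dict = {}

def create_or_get_group(group_size: int, name: str):
    """Sketch: split world ranks into contiguous blocks of `group_size`,
    build a process group per block, and cache the one this rank belongs to
    so each fine-grained TP component can reuse it."""
    key = (group_size, name)
    if key not in _group_cache:
        world_size = dist.get_world_size()  # assumed divisible by group_size
        rank = dist.get_rank()
        # dist.new_group must be called collectively on every rank for every group
        groups = [
            dist.new_group(ranks=list(range(start, start + group_size)))
            for start in range(0, world_size, group_size)
        ]
        _group_cache[key] = groups[rank // group_size]
    return _group_cache[key]
```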


@@ -36,6 +36,8 @@ How to extend a new linear op? Taking column parallel op as an example:
 Row parallel op follows a similar approach - inherit from RowColumnParallelOp and register the new class in get_row_parallel_op.
 """
+import re
+from functools import lru_cache
 from typing import Optional, Union
 import torch
@@ -605,7 +607,8 @@ class SequenceRowParallelOp(CustomRowParallelOp):
 def _get_column_parallel_op(
     prefix, layer
 ) -> Optional[Union[MLPColumnParallelOp, SequenceColumnParallelOp]]:
-    if mlp_tp_enable() and "gate_up_proj" in prefix:
+    if "gate_up_proj" in prefix and mlp_tp_enable(
+    ) and not is_moe_layer(prefix):
         return MLPColumnParallelOp(layer)
     if enable_sp():
         if "shared_expert" in prefix:
@@ -629,7 +632,7 @@ def _get_row_parallel_op(
 ) -> Optional[Union[MLPRowParallelOp, OProjRowParallelOp,
                     Flashcomm2OProjRowParallelOp, MatmulAllreduceRowParallelOp,
                     SequenceRowParallelOp]]:
-    if "down_proj" in prefix and mlp_tp_enable():
+    if "down_proj" in prefix and mlp_tp_enable() and not is_moe_layer(prefix):
         return MLPRowParallelOp(layer)
     if "o_proj" in prefix and oproj_tp_enable():
         return OProjRowParallelOp(layer)
@@ -681,3 +684,27 @@ def get_replicated_op(disable_tp, prefix,
         return None
     return CustomReplicatedOp(layer)
+
+
+def is_moe_layer(prefix: str) -> bool:
+
+    @lru_cache(maxsize=1)
+    def get_moe_params():
+        from vllm.config import get_current_vllm_config
+        vllm_config = get_current_vllm_config()
+        config = vllm_config.model_config.hf_config
+        n_routed_experts = getattr(config, 'n_routed_experts', 0)
+        first_k_dense_replace = getattr(config, 'first_k_dense_replace',
+                                        float('inf'))
+        moe_layer_freq = getattr(config, 'moe_layer_freq', 1)
+        return n_routed_experts, first_k_dense_replace, moe_layer_freq
+
+    match = re.search(r'layers\.(\d+)\.', prefix)
+    if match is None:
+        return False
+    layer_idx = int(match.group(1))
+    n_routed_experts, first_k_dense_replace, moe_layer_freq = get_moe_params()
+    return (n_routed_experts is not None and layer_idx >= first_k_dense_replace
+            and layer_idx % moe_layer_freq == 0)