[Feat] Support MLP_TP feature, exclude MOE layer (#4999)

#4257 This PR implements the dense_ffn TP of the first three layers of
the deepseek model, I have refactored this PR and used very little code
to support the implementation of this feature.
This PR adds a function `is_moe_layer` to mlp_tp, which supports MLP TP
in models with both mlp and moe, such as deepseek or chat GLM.


- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

Signed-off-by: zzhx1 <zzh_201018@outlook.com>
Co-authored-by: 子潜 <ziqian@U-DMKXH32D-2015.local>
Co-authored-by: chenxiao <Jaychou1620@Gmail.com>
Co-authored-by: Jade Zheng <zheng.shoujian@outlook.com>
This commit is contained in:
zzhxxx
2025-12-18 20:06:53 +08:00
committed by GitHub
parent 5a88e3333b
commit a74a1196c5
3 changed files with 37 additions and 23 deletions

View File

@@ -36,6 +36,8 @@ How to extend a new linear op? Taking column parallel op as an example:
Row parallel op follows a similar approach - inherit from RowColumnParallelOp and register the new class in get_row_parallel_op.
"""
import re
from functools import lru_cache
from typing import Optional, Union
import torch
@@ -605,7 +607,8 @@ class SequenceRowParallelOp(CustomRowParallelOp):
def _get_column_parallel_op(
prefix, layer
) -> Optional[Union[MLPColumnParallelOp, SequenceColumnParallelOp]]:
if mlp_tp_enable() and "gate_up_proj" in prefix:
if "gate_up_proj" in prefix and mlp_tp_enable(
) and not is_moe_layer(prefix):
return MLPColumnParallelOp(layer)
if enable_sp():
if "shared_expert" in prefix:
@@ -629,7 +632,7 @@ def _get_row_parallel_op(
) -> Optional[Union[MLPRowParallelOp, OProjRowParallelOp,
Flashcomm2OProjRowParallelOp, MatmulAllreduceRowParallelOp,
SequenceRowParallelOp]]:
if "down_proj" in prefix and mlp_tp_enable():
if "down_proj" in prefix and mlp_tp_enable() and not is_moe_layer(prefix):
return MLPRowParallelOp(layer)
if "o_proj" in prefix and oproj_tp_enable():
return OProjRowParallelOp(layer)
@@ -681,3 +684,27 @@ def get_replicated_op(disable_tp, prefix,
return None
return CustomReplicatedOp(layer)
def is_moe_layer(prefix: str) -> bool:
@lru_cache(maxsize=1)
def get_moe_params():
from vllm.config import get_current_vllm_config
vllm_config = get_current_vllm_config()
config = vllm_config.model_config.hf_config
n_routed_experts = getattr(config, 'n_routed_experts', 0)
first_k_dense_replace = getattr(config, 'first_k_dense_replace',
float('inf'))
moe_layer_freq = getattr(config, 'moe_layer_freq', 1)
return n_routed_experts, first_k_dense_replace, moe_layer_freq
match = re.search(r'layers\.(\d+)\.', prefix)
if match is None:
return False
layer_idx = int(match.group(1))
n_routed_experts, first_k_dense_replace, moe_layer_freq = get_moe_params()
return (n_routed_experts is not None and layer_idx >= first_k_dense_replace
and layer_idx % moe_layer_freq == 0)