Add DeepSeek V3/R1 shared experts fusion (#4918)

Authored by Xiaoyu Zhang on 2025-04-04 16:59:29 +08:00; committed via GitHub.
parent 6ff9c6a5e7
commit 924ca7c92c
14 changed files with 536 additions and 36 deletions
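
Editor's note: in broad strokes, this change folds the DeepSeek V3/R1 shared expert into the routed-expert pool so a single fused MoE kernel serves both, instead of running a separate shared-expert MLP next to the router. A back-of-the-envelope sketch (not part of the diff) of the resulting sizes, using the V3/R1 geometry the diff checks for and a hypothetical tensor-parallel size of 8:

    # Sketch only: how the fused expert pool is sized.
    n_routed_experts = 256      # config.n_routed_experts for DeepSeek V3/R1
    num_experts_per_tok = 8     # config.num_experts_per_tok (top-k routing)
    tp_size = 8                 # hypothetical tensor-parallel world size

    # By default one shared-expert replica is appended per TP rank.
    n_share_experts_fusion = tp_size

    fused_num_experts = n_routed_experts + n_share_experts_fusion       # 264
    fused_top_k = num_experts_per_tok + min(n_share_experts_fusion, 1)  # 9
    print(fused_num_experts, fused_top_k)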


@@ -16,12 +16,14 @@
# https://github.com/vllm-project/vllm/blob/fb6af8bc086328ca6659e72d11ffd4309ce4de22/vllm/model_executor/models/deepseek_v2.py
"""Inference-only DeepseekV2 model."""
import logging
import os
from typing import Any, Dict, Iterable, Optional, Tuple
import torch
import torch.nn.functional as F
from torch import nn
from tqdm import tqdm
from transformers import PretrainedConfig
from sglang.srt.distributed import (
@@ -87,6 +89,8 @@ if _is_hip:
expert_distribution_recorder = ExpertDistributionRecorder()
logger = logging.getLogger(__name__)
class DeepseekV2MLP(nn.Module):
def __init__(
@@ -168,6 +172,12 @@ class DeepseekV2MoE(nn.Module):
self.tp_size = get_tensor_model_parallel_world_size()
self.routed_scaling_factor = config.routed_scaling_factor
self.n_shared_experts = config.n_shared_experts
self.n_share_experts_fusion = (
global_server_args_dict["n_share_experts_fusion"]
if global_server_args_dict["n_share_experts_fusion"] is not None
else 0
)
self.routed_scaling_factor = config.routed_scaling_factor
if self.tp_size > config.n_routed_experts:
raise ValueError(
@@ -188,9 +198,10 @@ class DeepseekV2MoE(nn.Module):
if global_server_args_dict["enable_deepep_moe"]
else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
)
self.experts = MoEImpl(
num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok,
num_experts=config.n_routed_experts + self.n_share_experts_fusion,
top_k=config.num_experts_per_tok + min(self.n_share_experts_fusion, 1),
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
renormalize=config.norm_topk_prob,
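
Editor's note: the `min(self.n_share_experts_fusion, 1)` term adds exactly one routing slot whenever fusion is active, so each token keeps its usual top-k routed experts and additionally lands on one of the appended shared-expert replicas. A minimal torch illustration of that selection; the replica is picked uniformly at random here purely for the sketch (the commit's actual top-k kernels live elsewhere in the 14 changed files):

    import torch

    def fused_topk_sketch(router_logits, top_k, n_routed, n_fused):
        # Usual top-k over the routed experts only (IDs 0..n_routed-1).
        routed_ids = torch.topk(router_logits[:, :n_routed], k=top_k, dim=-1).indices
        # One extra slot that always hits a shared-expert replica
        # (IDs n_routed..n_routed+n_fused-1), chosen at random per token here.
        shared_ids = n_routed + torch.randint(n_fused, (router_logits.size(0), 1))
        return torch.cat([routed_ids, shared_ids], dim=-1)

    logits = torch.randn(4, 256)   # 4 tokens, 256 routed experts
    ids = fused_topk_sketch(logits, top_k=8, n_routed=256, n_fused=8)
    print(ids.shape)               # torch.Size([4, 9])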
@@ -207,7 +218,7 @@ class DeepseekV2MoE(nn.Module):
),
)
if config.n_shared_experts is not None:
if config.n_shared_experts is not None and self.n_share_experts_fusion == 0:
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
# disable tp for shared experts when enable deepep moe
if not global_server_args_dict["enable_deepep_moe"]:
@@ -267,8 +278,10 @@ class DeepseekV2MoE(nn.Module):
return self.forward_deepep(hidden_states, forward_mode)
def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
if self.n_shared_experts is not None:
if self.n_shared_experts is not None and self.n_share_experts_fusion == 0:
shared_output = self.shared_experts(hidden_states)
else:
shared_output = None
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states)
final_hidden_states = (
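
Editor's note: with fusion enabled the separate `shared_experts` module is never constructed, so `forward_normal` skips the standalone shared-expert pass and the shared computation rides through the fused experts kernel. A condensed sketch of that branch, with `gate` and `fused_moe` as stand-in callables and the routed-scaling details omitted:

    def forward_normal_sketch(hidden_states, shared_experts, gate, fused_moe,
                              n_share_experts_fusion):
        # The standalone shared-expert pass only happens when fusion is off.
        if shared_experts is not None and n_share_experts_fusion == 0:
            shared_output = shared_experts(hidden_states)
        else:
            shared_output = None
        router_logits = gate(hidden_states)           # (num_tokens, n_experts)
        out = fused_moe(hidden_states, router_logits)
        if shared_output is not None:
            out = out + shared_output                 # scaling omitted in this sketch
        return out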
@@ -1315,7 +1328,28 @@ class DeepseekV2ForCausalLM(nn.Module):
) -> None:
super().__init__()
self.config = config
self.tp_size = get_tensor_model_parallel_world_size()
self.quant_config = quant_config
self.n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
# Only Deepseek V3/R1 can use shared experts fusion optimization now.
if (
global_server_args_dict.get("disable_shared_experts_fusion", False)
or self.config.architectures[0] != "DeepseekV3ForCausalLM"
or self.config.n_routed_experts != 256
or self.config.routed_scaling_factor != 2.5
):
self.n_share_experts_fusion = None
global_server_args_dict["n_share_experts_fusion"] = None
logger.info(
"Only Deepseek V3/R1 can use shared experts fusion optimization. Shared experts fusion optimization is disabled."
)
elif self.n_share_experts_fusion is None:
global_server_args_dict["n_share_experts_fusion"] = self.tp_size
self.n_share_experts_fusion = self.tp_size
logger.info(
f"Shared experts fusion optimization is default enabled in DeepSeek V3/R1, and n_share_experts_fusion is set to {self.tp_size}. You can tune it by setting --n_share_experts_fusion or disable it by setting --disable_shared_experts_fusion."
)
self.model = DeepseekV2Model(
config, quant_config, prefix=add_prefix("model", prefix)
)
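
Editor's note: the constructor only enables the optimization for the exact V3/R1 geometry and force-disables it otherwise; when eligible and unset, `n_share_experts_fusion` defaults to the TP world size. The same gate, written as a standalone predicate over a config object with the attributes the diff checks (a sketch, not code from the commit):

    def can_fuse_shared_experts(config, disable_flag: bool = False) -> bool:
        # Mirrors the constructor's check: only DeepSeek V3/R1
        # (256 routed experts, routed_scaling_factor 2.5) qualifies.
        return not (
            disable_flag
            or config.architectures[0] != "DeepseekV3ForCausalLM"
            or config.n_routed_experts != 256
            or config.routed_scaling_factor != 2.5
        )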
@@ -1352,6 +1386,43 @@ class DeepseekV2ForCausalLM(nn.Module):
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
if self.n_share_experts_fusion is not None and self.n_share_experts_fusion > 0:
weights_list = list(weights)
weights_dict = dict(weights_list)
suffix_list = [
"down_proj.weight",
"down_proj.weight_scale_inv",
"gate_proj.weight",
"gate_proj.weight_scale_inv",
"up_proj.weight",
"up_proj.weight_scale_inv",
]
names_to_remove = []
for moe_layer in tqdm(
range(
self.config.first_k_dense_replace,
self.config.num_hidden_layers,
self.config.moe_layer_freq,
),
desc=f"Cloning {self.n_share_experts_fusion} "
"replicas of the shared expert into MoE",
):
for num_repeat in range(self.n_share_experts_fusion):
for suffix in suffix_list:
shared_expert_weight_name = (
f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}"
)
weights_list.append(
(
f"model.layers.{moe_layer}."
f"mlp.experts."
f"{self.config.n_routed_experts + num_repeat}"
f".{suffix}",
weights_dict[shared_expert_weight_name].clone(),
)
)
names_to_remove += [shared_expert_weight_name]
weights = [w for w in weights_list if w[0] not in names_to_remove]
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
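
Editor's note: during `load_weights`, each MoE layer's shared-expert tensors are cloned `n_share_experts_fusion` times and re-registered under the routed-expert namespace at indices `n_routed_experts` through `n_routed_experts + n_share_experts_fusion - 1`, and the original `shared_experts.*` names are dropped. The renaming alone, isolated as a sketch (weight-scale suffixes omitted):

    def fused_weight_name_map(layer, n_routed_experts, n_fused,
                              suffixes=("gate_proj.weight", "up_proj.weight",
                                        "down_proj.weight")):
        """Map destination expert-weight names to their shared-expert sources."""
        mapping = {}
        for k in range(n_fused):
            for s in suffixes:
                src = f"model.layers.{layer}.mlp.shared_experts.{s}"
                dst = f"model.layers.{layer}.mlp.experts.{n_routed_experts + k}.{s}"
                mapping[dst] = src
        return mapping

    # An example MoE layer (index 3) with 256 routed experts and 2 fused replicas:
    for dst, src in fused_weight_name_map(3, 256, 2).items():
        print(src, "->", dst)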
@@ -1364,7 +1435,12 @@ class DeepseekV2ForCausalLM(nn.Module):
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.n_routed_experts,
num_experts=self.config.n_routed_experts
+ (
self.n_share_experts_fusion
if self.n_share_experts_fusion is not None
else 0
),
)
params_dict = dict(self.named_parameters())