[Feature] Integrate DeepEP into SGLang (#4232)

Co-authored-by: Cheng Wan <cwan39@gatech.edu>
Co-authored-by: Xuting Zhou <xutingz@nvidia.com>
Authored by Jinyan Chen on 2025-03-19 23:16:31 +08:00, committed by GitHub
parent f9c53cbb42
commit f44db16c8e
12 changed files with 1228 additions and 35 deletions

python/sglang/srt/models/deepseek_v2.py (165 changed lines, Executable file → Normal file)

@@ -26,6 +26,7 @@ from transformers import PretrainedConfig
from sglang.srt.distributed import (
get_tensor_model_parallel_world_size,
parallel_state,
tensor_model_parallel_all_reduce,
)
from sglang.srt.layers.activation import SiluAndMul
@@ -47,8 +48,10 @@ from sglang.srt.layers.linear import (
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.fp8_utils import (
block_quant_to_tensor_quant,
@@ -65,7 +68,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import add_prefix, is_cuda, is_cuda_available, is_hip
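The new imports above wire the DeepEP path into this file: DeepEPMoE and DeepEPDispatcher carry the expert computation and the all-to-all token exchange, select_experts exposes the router, and ForwardMode lets the MoE layer see whether the local micro-batch is idle. The whole feature is gated by a single server argument; a minimal sketch of how that switch is read back (the launch flag spelling is assumed from this PR's title, so check the server-args changes elsewhere in this commit for the authoritative name):

# Launch sketch (flag name assumed):
#   python -m sglang.launch_server ... --tp 8 --enable-deepep-moe
from sglang.srt.managers.schedule_batch import global_server_args_dict

# Every branch added in this diff keys off this single boolean.
use_deepep = global_server_args_dict["enable_deepep_moe"]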
@@ -87,6 +90,8 @@ class DeepseekV2MLP(nn.Module):
quant_config: Optional[QuantizationConfig] = None,
reduce_results: bool = True,
prefix: str = "",
tp_rank: Optional[int] = None,
tp_size: Optional[int] = None,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
@@ -95,6 +100,8 @@ class DeepseekV2MLP(nn.Module):
bias=False,
quant_config=quant_config,
prefix=add_prefix("gate_up_proj", prefix),
tp_rank=tp_rank,
tp_size=tp_size,
)
self.down_proj = RowParallelLinear(
intermediate_size,
@@ -103,6 +110,8 @@ class DeepseekV2MLP(nn.Module):
quant_config=quant_config,
reduce_results=reduce_results,
prefix=add_prefix("down_proj", prefix),
tp_rank=tp_rank,
tp_size=tp_size,
)
if hidden_act != "silu":
raise ValueError(
@@ -167,7 +176,11 @@ class DeepseekV2MoE(nn.Module):
self.gate = MoEGate(config=config, prefix=add_prefix("gate", prefix))
MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
MoEImpl = (
DeepEPMoE
if global_server_args_dict["enable_deepep_moe"]
else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
)
self.experts = MoEImpl(
num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok,
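The MoEImpl selection in the hunk above is a three-way priority: DeepEP takes precedence over plain expert parallelism, and the fused Triton kernel stays the default. A minimal restatement, assuming the two flags are not meant to be combined:

from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

def pick_moe_impl(server_args: dict):
    if server_args.get("enable_deepep_moe"):
        return DeepEPMoE   # DeepEP all-to-all expert parallelism
    if server_args.get("enable_ep_moe"):
        return EPMoE       # existing expert-parallel MoE
    return FusedMoE        # default fused Triton MoE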
@@ -184,16 +197,59 @@ class DeepseekV2MoE(nn.Module):
if config.n_shared_experts is not None:
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
self.shared_experts = DeepseekV2MLP(
hidden_size=config.hidden_size,
intermediate_size=intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
reduce_results=False,
prefix=add_prefix("shared_experts", prefix),
# disable tp for shared experts when deepep moe is enabled
if not global_server_args_dict["enable_deepep_moe"]:
self.shared_experts = DeepseekV2MLP(
hidden_size=config.hidden_size,
intermediate_size=intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
reduce_results=False,
prefix=add_prefix("shared_experts", prefix),
)
else:
self.shared_experts = DeepseekV2MLP(
hidden_size=config.hidden_size,
intermediate_size=intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
reduce_results=False,
prefix=add_prefix("shared_experts", prefix),
tp_rank=0,
tp_size=1,
)
if global_server_args_dict["enable_deepep_moe"]:
self.num_experts = config.n_routed_experts
self.top_k = config.num_experts_per_tok
self.renormalize = config.norm_topk_prob
self.topk_group = config.topk_group
self.num_expert_group = config.n_group
self.correction_bias = (
self.gate.e_score_correction_bias.data
if self.gate.e_score_correction_bias is not None
else None
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
self.deepep_dispatcher = DeepEPDispatcher(
group=parallel_state.get_tp_group().device_group,
router_topk=self.top_k,
permute_fusion=True,
num_experts=config.n_routed_experts,
num_local_experts=config.n_routed_experts // self.tp_size,
hidden_size=config.hidden_size,
params_dtype=config.torch_dtype,
)
def forward(
self, hidden_states: torch.Tensor, forward_mode: Optional[ForwardMode] = None
) -> torch.Tensor:
if not global_server_args_dict["enable_deepep_moe"]:
return self.forward_normal(hidden_states)
else:
return self.forward_deepep(hidden_states, forward_mode)
def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
if self.n_shared_experts is not None:
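Two things in the constructor changes above are worth calling out. First, when DeepEP is enabled the shared experts are built with tp_rank=0 / tp_size=1, i.e. replicated rather than tensor-parallel, so each rank runs them locally on its own tokens and no all-reduce is needed for the shared path. Second, the DeepEPDispatcher is created over the TP device group and the routed experts are sharded evenly across it. Illustrative numbers only (assumed, not taken from the diff):

# Example of the sharding implied by num_local_experts = n_routed_experts // tp_size
n_routed_experts = 160   # example config.n_routed_experts
tp_size = 8              # the EP group is the TP group in this integration
num_local_experts = n_routed_experts // tp_size   # -> 20 experts per rank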
@@ -208,6 +264,59 @@ class DeepseekV2MoE(nn.Module):
final_hidden_states = final_hidden_states + shared_output
if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
return final_hidden_states.view(num_tokens, hidden_dim)
def forward_deepep(
self, hidden_states: torch.Tensor, forward_mode: ForwardMode
) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
shared_output = None
topk_idx = torch.full(
(0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
)
topk_weights = torch.empty(
(0, self.top_k), dtype=torch.float32, device=hidden_states.device
)
if forward_mode is not None and not forward_mode.is_idle():
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states)
if self.n_shared_experts is not None:
shared_output = self.shared_experts(hidden_states)
topk_weights, topk_idx = select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
use_grouped_topk=True,
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
correction_bias=self.correction_bias,
)
if self.tp_size > 1:
recv_hidden_states, topk_idx, topk_weights, tokens_per_expert = (
self.deepep_dispatcher.dispatch(
hidden_states,
topk_idx,
topk_weights,
self.num_experts,
forward_mode,
)
)
final_hidden_states = (
self.experts(
hidden_states=recv_hidden_states,
tokens_per_expert=tokens_per_expert,
forward_mode=forward_mode,
)
* self.routed_scaling_factor
)
if self.tp_size > 1:
final_hidden_states = self.deepep_dispatcher.combine(
final_hidden_states, forward_mode
)
if shared_output is not None:
final_hidden_states = final_hidden_states + shared_output
return final_hidden_states.view(num_tokens, hidden_dim)
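forward_deepep replaces the single fused-MoE kernel with an explicit pipeline: route tokens, dispatch them across ranks with an all-to-all, run only the locally owned experts, then combine the results back at the source rank; idle ranks keep the empty topk tensors so the collective calls stay matched. The per-token semantics this pipeline implements, as a hedged sketch that ignores the fused DeepEP kernels and the grouped top-k details:

import torch

def reference_moe(token, experts, gate_logits, top_k, routed_scaling_factor=1.0):
    # Routing: top-k probabilities over all experts for this one token.
    weights, idx = torch.topk(torch.softmax(gate_logits, dim=-1), top_k)
    # "dispatch" would send the token to the ranks owning experts idx[j];
    # "combine" brings their outputs back and weight-sums them here.
    out = sum(w * experts[i](token) for w, i in zip(weights, idx.tolist()))
    return out * routed_scaling_factor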
@@ -959,15 +1068,25 @@ class DeepseekV2DecoderLayer(nn.Module):
if get_tensor_model_parallel_world_size() > 1:
# all gather and all reduce
if self.dp_size != 1:
if get_attention_tp_rank() == 0:
hidden_states += residual
hidden_states, local_hidden_states = (
forward_batch.gathered_buffer,
hidden_states,
)
dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
dp_scatter(residual, hidden_states, forward_batch)
hidden_states = self.post_attention_layernorm(hidden_states)
if global_server_args_dict["enable_deepep_moe"] and isinstance(
self.mlp, DeepseekV2MoE
):
if hidden_states.shape[0] != 0:
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual
)
hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)
return hidden_states, residual
else:
if get_attention_tp_rank() == 0:
hidden_states += residual
hidden_states, local_hidden_states = (
forward_batch.gathered_buffer,
hidden_states,
)
dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
dp_scatter(residual, hidden_states, forward_batch)
hidden_states = self.post_attention_layernorm(hidden_states)
else:
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
hidden_states, residual = self.post_attention_layernorm(
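In the DeepEP branch of the decoder layer above, the dp_gather_partial / all-reduce of the normal path is deliberately skipped: dispatch and combine inside the MoE layer already move activations across ranks. The layernorm only runs on ranks whose local micro-batch is non-empty, but self.mlp is still called so every rank joins the collectives. A condensed sketch of that control flow (not the full code):

if global_server_args_dict["enable_deepep_moe"] and isinstance(self.mlp, DeepseekV2MoE):
    # No dp_gather / all_reduce here; DeepEP handles cross-rank movement itself.
    if hidden_states.shape[0] != 0:
        # Only normalize when this rank actually has tokens.
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
    # Always run the MoE so idle ranks still participate in dispatch/combine.
    hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)
    return hidden_states, residual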
@@ -1099,7 +1218,11 @@ class DeepseekV2ForCausalLM(nn.Module):
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
MoEImpl = (
DeepEPMoE
if global_server_args_dict["enable_deepep_moe"]
else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
)
expert_params_mapping = MoEImpl.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",