diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 5955332f5..e422a5038 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -194,6 +194,14 @@ class MoEGate(nn.Module):
         return logits
 
 
+def is_non_idle_and_non_empty(forward_mode, hidden_states):
+    return (
+        (forward_mode is not None)
+        and not forward_mode.is_idle()
+        and hidden_states.shape[0] > 0
+    )
+
+
 class DeepseekV2MoE(nn.Module):
     def __init__(
         self,
@@ -259,11 +267,12 @@ class DeepseekV2MoE(nn.Module):
             ),
         )
 
+        self.top_k = config.num_experts_per_tok
+
         if global_server_args_dict["enable_deepep_moe"]:
             # TODO: we will support tp < ep in the future
             self.ep_size = get_tensor_model_parallel_world_size()
             self.num_experts = config.n_routed_experts
-            self.top_k = config.num_experts_per_tok
             self.renormalize = config.norm_topk_prob
             self.topk_group = config.topk_group
             self.num_expert_group = config.n_group
@@ -286,41 +295,30 @@ class DeepseekV2MoE(nn.Module):
                 return_recv_hook=True,
             )
 
+    @property
+    def _enable_deepep_moe(self):
+        return global_server_args_dict["enable_deepep_moe"]
+
     def forward(
         self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
     ) -> torch.Tensor:
-        if not global_server_args_dict["enable_deepep_moe"]:
-            return self.forward_normal(hidden_states)
-        else:
-            return self.forward_deepep(hidden_states, forward_batch)
-
-    def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        shared_output = self._forward_shared_experts(hidden_states)
-        # router_logits: (num_tokens, n_experts)
-        router_logits = self.gate(hidden_states)
-        final_hidden_states = self.experts(
-            hidden_states=hidden_states, router_logits=router_logits
-        )
-        final_hidden_states *= self.routed_scaling_factor
-        if shared_output is not None:
-            final_hidden_states = final_hidden_states + shared_output
-        if self.tp_size > 1:
-            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
-        return final_hidden_states
-
-    def forward_deepep(
-        self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
-    ) -> torch.Tensor:
-        forward_mode = forward_batch.forward_mode
-        shared_output = None
-        if (
-            forward_mode is not None
-            and not forward_mode.is_idle()
-            and hidden_states.shape[0] > 0
+        if (not self._enable_deepep_moe) or is_non_idle_and_non_empty(
+            forward_mode, hidden_states
         ):
             # router_logits: (num_tokens, n_experts)
             router_logits = self.gate(hidden_states)
-            shared_output = self._forward_shared_experts(hidden_states)
+        else:
+            router_logits = None
+
+        if (self.n_share_experts_fusion == 0) and (
+            (not self._enable_deepep_moe)
+            or is_non_idle_and_non_empty(forward_mode, hidden_states)
+        ):
+            shared_output = self.shared_experts(hidden_states)
+        else:
+            shared_output = None
+
+        if self._enable_deepep_moe and (router_logits is not None):
             topk_weights, topk_idx = select_experts(
                 hidden_states=hidden_states,
                 router_logits=router_logits,
@@ -340,7 +338,8 @@ class DeepseekV2MoE(nn.Module):
             topk_weights = torch.empty(
                 (0, self.top_k), dtype=torch.float32, device=hidden_states.device
             )
-        if self.ep_size > 1:
+
+        if self._enable_deepep_moe and (self.ep_size > 1):
             # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
             (
                 hidden_states,
@@ -357,36 +356,41 @@ class DeepseekV2MoE(nn.Module):
                 topk_weights,
                 forward_mode=forward_mode,
             )
-        final_hidden_states = self.experts(
-            hidden_states=hidden_states,
-            topk_idx=topk_idx,
-            topk_weights=topk_weights,
-            reorder_topk_ids=reorder_topk_ids,
-            seg_indptr=seg_indptr,
-            masked_m=masked_m,
-            expected_m=expected_m,
-            num_recv_tokens_per_expert=num_recv_tokens_per_expert,
-            forward_mode=forward_mode,
-        )
-        if self.ep_size > 1:
+
+        if self._enable_deepep_moe:
+            final_hidden_states = self.experts(
+                hidden_states=hidden_states,
+                topk_idx=topk_idx,
+                topk_weights=topk_weights,
+                reorder_topk_ids=reorder_topk_ids,
+                seg_indptr=seg_indptr,
+                masked_m=masked_m,
+                expected_m=expected_m,
+                num_recv_tokens_per_expert=num_recv_tokens_per_expert,
+                forward_mode=forward_mode,
+            )
+        else:
+            final_hidden_states = self.experts(
+                hidden_states=hidden_states, router_logits=router_logits
+            )
+
+        if self._enable_deepep_moe and (self.ep_size > 1):
             final_hidden_states = self.deepep_dispatcher.combine(
                 final_hidden_states,
                 topk_idx,
                 topk_weights,
                 forward_mode,
             )
+
         final_hidden_states *= self.routed_scaling_factor
 
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
 
-        return final_hidden_states
+        if (not self._enable_deepep_moe) and (self.tp_size > 1):
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
 
-    def _forward_shared_experts(self, hidden_states):
-        if self.n_share_experts_fusion == 0:
-            return self.shared_experts(hidden_states)
-        else:
-            return None
+        return final_hidden_states
 
 
 def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: