diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 8017cefa4..c0088faca 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -324,6 +324,104 @@ class DeepseekV2MoE(nn.Module):
             if name not in ["correction_bias"]
         ]
 
+    def forward(
+        self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
+    ) -> torch.Tensor:
+        if not self._enable_deepep_moe:
+            return self.forward_normal(hidden_states)
+        else:
+            return self.forward_deepep(hidden_states, forward_batch)
+
+    def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        shared_output = self._forward_shared_experts(hidden_states)
+        # router_logits: (num_tokens, n_experts)
+        router_logits = self.gate(hidden_states)
+        final_hidden_states = self.experts(
+            hidden_states=hidden_states, router_logits=router_logits
+        )
+        final_hidden_states *= self.routed_scaling_factor
+        if shared_output is not None:
+            final_hidden_states = final_hidden_states + shared_output
+        if self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+        return final_hidden_states
+
+    def forward_deepep(
+        self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
+    ) -> torch.Tensor:
+        forward_mode = forward_batch.forward_mode
+        shared_output = None
+        if is_non_idle_and_non_empty(forward_mode, hidden_states):
+            # router_logits: (num_tokens, n_experts)
+            router_logits = self.gate(hidden_states)
+            shared_output = self._forward_shared_experts(hidden_states)
+            topk_weights, topk_idx = select_experts(
+                hidden_states=hidden_states,
+                router_logits=router_logits,
+                top_k=self.top_k,
+                use_grouped_topk=True,
+                renormalize=self.renormalize,
+                topk_group=self.topk_group,
+                num_expert_group=self.num_expert_group,
+                correction_bias=self.correction_bias,
+                routed_scaling_factor=self.routed_scaling_factor,
+                num_token_non_padded=forward_batch.num_token_non_padded,
+            )
+        else:
+            topk_idx = torch.full(
+                (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
+            )
+            topk_weights = torch.empty(
+                (0, self.top_k), dtype=torch.float32, device=hidden_states.device
+            )
+        if self.ep_size > 1:
+            # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
+            (
+                hidden_states,
+                topk_idx,
+                topk_weights,
+                reorder_topk_ids,
+                num_recv_tokens_per_expert,
+                seg_indptr,
+                masked_m,
+                expected_m,
+            ) = self.deepep_dispatcher.dispatch(
+                hidden_states=hidden_states,
+                topk_idx=topk_idx,
+                topk_weights=topk_weights,
+                forward_mode=forward_mode,
+            )
+        final_hidden_states = self.experts(
+            hidden_states=hidden_states,
+            topk_idx=topk_idx,
+            topk_weights=topk_weights,
+            reorder_topk_ids=reorder_topk_ids,
+            seg_indptr=seg_indptr,
+            masked_m=masked_m,
+            expected_m=expected_m,
+            num_recv_tokens_per_expert=num_recv_tokens_per_expert,
+            forward_mode=forward_mode,
+        )
+        if self.ep_size > 1:
+            final_hidden_states = self.deepep_dispatcher.combine(
+                hidden_states=final_hidden_states,
+                topk_idx=topk_idx,
+                topk_weights=topk_weights,
+                forward_mode=forward_mode,
+            )
+        final_hidden_states *= self.routed_scaling_factor
+
+        if shared_output is not None:
+            final_hidden_states = final_hidden_states + shared_output
+
+        return final_hidden_states
+
+    def _forward_shared_experts(self, hidden_states):
+        if self.n_share_experts_fusion == 0:
+            return self.shared_experts(hidden_states)
+        else:
+            return None
+
     def op_gate(self, state):
         if (not self._enable_deepep_moe) or is_non_idle_and_non_empty(
             state.forward_batch.forward_mode, state.hidden_states_mlp_input
@@ -1353,17 +1451,29 @@ class DeepseekV2DecoderLayer(nn.Module):
         residual: Optional[torch.Tensor],
         zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
-        return execute_operations(
-            inputs=dict(
-                positions=positions,
-                hidden_states=hidden_states,
-                forward_batch=forward_batch,
-                residual=residual,
-                zero_allocator=zero_allocator,
-            ),
-            operations=compute_layer_operations(self),
+        hidden_states, residual = self.layer_communicator.prepare_attn(
+            hidden_states, residual, forward_batch
         )
 
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            forward_batch=forward_batch,
+            zero_allocator=zero_allocator,
+        )
+
+        hidden_states, residual = self.layer_communicator.prepare_mlp(
+            hidden_states, residual, forward_batch
+        )
+
+        hidden_states = self.mlp(hidden_states, forward_batch)
+
+        hidden_states, residual = self.layer_communicator.postprocess_layer(
+            hidden_states, residual, forward_batch
+        )
+
+        return hidden_states, residual
+
     def op_comm_prepare_attn(
         self,
         state,