diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py index 6d8605f77..8bb9224e3 100644 --- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py +++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -185,7 +185,6 @@ class DeepEPDispatcher: previous_event=None, num_max_dispatch_tokens_per_rank: int = 128, ) -> Tuple[torch.Tensor, torch.Tensor]: - self.hidden_shape = hidden_states.shape topk_idx = topk_idx.to(torch.int64) # Todo: enable low latency dispatch if True: # not forward_mode.is_decode(): @@ -375,7 +374,7 @@ class DeepEPDispatcher: hidden_states, self.topk_idx, self.topk_weights, self.handle ) self.handle = None - return hidden_states.view(self.hidden_shape) + return hidden_states def combine_normal(self, x: torch.Tensor, handle: Tuple, previous_event=None): combined_x, _, event = self.buffer_normal.combine( diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index c62dacec9..2214030c8 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -250,8 +250,6 @@ class DeepseekV2MoE(nn.Module): return self.forward_deepep(hidden_states, forward_mode) def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor: - num_tokens, hidden_dim = hidden_states.shape - hidden_states = hidden_states.view(-1, hidden_dim) if self.n_shared_experts is not None: shared_output = self.shared_experts(hidden_states) # router_logits: (num_tokens, n_experts) @@ -264,13 +262,11 @@ class DeepseekV2MoE(nn.Module): final_hidden_states = final_hidden_states + shared_output if self.tp_size > 1: final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) - return final_hidden_states.view(num_tokens, hidden_dim) + return final_hidden_states def forward_deepep( self, hidden_states: torch.Tensor, forward_mode: ForwardMode ) -> torch.Tensor: - num_tokens, hidden_dim = hidden_states.shape - hidden_states = hidden_states.view(-1, hidden_dim) shared_output = None topk_idx = torch.full( (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device @@ -319,7 +315,7 @@ class DeepseekV2MoE(nn.Module): if shared_output is not None: final_hidden_states = final_hidden_states + shared_output - return final_hidden_states.view(num_tokens, hidden_dim) + return final_hidden_states def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: