From 77e929a1a2f3f7eadbf0d945c57156c4d9ac482f Mon Sep 17 00:00:00 2001
From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com>
Date: Fri, 4 Apr 2025 15:32:27 +0800
Subject: [PATCH] Support async DeepEP by splitting into two stages (#4995)

---
 .../srt/layers/moe/ep_moe/token_dispatcher.py | 118 +++++++++++++-----
 1 file changed, 86 insertions(+), 32 deletions(-)

diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py
index a2bbbf3bf..08b747b06 100644
--- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py
+++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py
@@ -113,7 +113,7 @@ class _DeepEPDispatcherImplBase:
 
         self.handle = None
 
-    def dispatch(
+    def dispatch_a(
         self,
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
@@ -123,12 +123,18 @@ class _DeepEPDispatcherImplBase:
     ):
         raise NotImplementedError
 
-    def combine(
+    def dispatch_b(self, *args, **kwargs):
+        raise NotImplementedError
+
+    def combine_a(
         self,
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
-    ) -> torch.Tensor:
+    ):
+        raise NotImplementedError
+
+    def combine_b(self, *args, **kwargs):
         raise NotImplementedError
 
 
@@ -142,7 +148,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         self.async_finish = async_finish
         self.src2dst = None
 
-    def dispatch(
+    def dispatch_a(
         self,
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
@@ -151,12 +157,20 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         num_max_dispatch_tokens_per_rank: int,
     ):
         topk_idx = topk_idx.to(torch.int64)
+        previous_event = Buffer.capture() if self.async_finish else None
+        return hidden_states, topk_idx, topk_weights, num_experts, previous_event
+
+    def dispatch_b(
+        self, hidden_states, topk_idx, topk_weights, num_experts, previous_event
+    ):
         (
             hidden_states,
             topk_idx,
             topk_weights,
             event,
-        ) = self._dispatch_normal(hidden_states, topk_idx, topk_weights, num_experts)
+        ) = self._dispatch_core(
+            hidden_states, topk_idx, topk_weights, num_experts, previous_event
+        )
         event.current_stream_wait() if self.async_finish else ()
         if hidden_states.shape[0] > 0:
             reorder_topk_ids, seg_indptr, hidden_states = self._deepep_permute(
@@ -187,15 +201,14 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
             expected_m,
         )
 
-    def _dispatch_normal(
+    def _dispatch_core(
         self,
         x: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
         num_experts: int,
+        previous_event,
     ):
-        previous_event = Buffer.capture() if self.async_finish else None
-
         (
             num_tokens_per_rank,
             num_tokens_per_rdma_rank,
@@ -279,12 +292,12 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         )
         return reorder_topk_ids, seg_indptr, gateup_input
 
-    def combine(
+    def combine_a(
         self,
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
-    ) -> torch.Tensor:
+    ):
         if hidden_states.shape[0] > 0:
             num_tokens = self.src2dst.shape[0] // self.router_topk
             output = torch.empty(
@@ -308,16 +321,15 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
                 device=hidden_states.device,
                 dtype=hidden_states.dtype,
             )
-        hidden_states, event = self._combine_normal(
-            output,
-        )
-        event.current_stream_wait() if self.async_finish else ()
+        previous_event = Buffer.capture() if self.async_finish else None
+        return output, previous_event
 
+    def combine_b(self, output, previous_event):
+        hidden_states, event = self._combine_core(output, previous_event)
+        event.current_stream_wait() if self.async_finish else ()
         return hidden_states
 
-    def _combine_normal(self, x: torch.Tensor):
-        previous_event = Buffer.capture() if self.async_finish else None
-
+    def _combine_core(self, x: torch.Tensor, previous_event):
         combined_x, _, event = self.buffer_normal.combine(
             x,
             self.handle,
@@ -346,7 +358,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         )
         self.return_recv_hook = return_recv_hook
 
-    def dispatch(
+    def dispatch_a(
         self,
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
@@ -361,13 +373,33 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
             * topk_idx.shape[1]
             + num_experts
         ) // num_experts
-        hidden_states, masked_m, event, hook = self._dispatch_low_latency(
+        hidden_states, masked_m, event, hook = self._dispatch_core(
             hidden_states,
             topk_idx,
             num_max_dispatch_tokens_per_rank,
             num_experts,
             use_fp8=True,
         )
+        return (
+            hidden_states,
+            topk_idx,
+            topk_weights,
+            masked_m,
+            expected_m,
+            event,
+            hook,
+        )
+
+    def dispatch_b(
+        self,
+        hidden_states,
+        topk_idx,
+        topk_weights,
+        masked_m,
+        expected_m,
+        event,
+        hook,
+    ):
         hook() if self.return_recv_hook else event.current_stream_wait()
 
         # TODO
@@ -389,7 +421,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
             expected_m,
         )
 
-    def _dispatch_low_latency(
+    def _dispatch_core(
         self,
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
@@ -443,22 +475,24 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         )
         return packed_recv_hidden, packed_recv_count, event, hook
 
-    def combine(
+    def combine_a(
         self,
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
-    ) -> torch.Tensor:
-        hidden_states, event, hook = self._combine_low_latency(
+    ):
+        hidden_states, event, hook = self._combine_core(
            hidden_states,
            topk_idx,
            topk_weights,
        )
-        hook() if self.return_recv_hook else event.current_stream_wait()
+        return hidden_states, event, hook
 
+    def combine_b(self, hidden_states, event, hook):
+        hook() if self.return_recv_hook else event.current_stream_wait()
         return hidden_states
 
-    def _combine_low_latency(
+    def _combine_core(
         self,
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
@@ -514,7 +548,11 @@ class DeepEPDispatcher:
                 **common_kwargs,
             )
 
-    def dispatch(
+    def dispatch(self, *args, **kwargs) -> Tuple:
+        self.dispatch_a(*args, **kwargs)
+        return self.dispatch_b()
+
+    def dispatch_a(
         self,
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
@@ -522,29 +560,45 @@ class DeepEPDispatcher:
         num_experts: int,
         num_max_dispatch_tokens_per_rank: int = 128,
         forward_mode: ForwardMode = None,
-    ) -> Tuple:
-        return self._get_dispatcher(forward_mode).dispatch(
+    ):
+        inner_state = self._get_impl(forward_mode).dispatch_a(
             hidden_states=hidden_states,
             topk_idx=topk_idx,
             topk_weights=topk_weights,
             num_experts=num_experts,
             num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank,
         )
+        self._dispatch_intermediate_state = forward_mode, inner_state
 
-    def combine(
+    def dispatch_b(self):
+        forward_mode, inner_state = self._dispatch_intermediate_state
+        del self._dispatch_intermediate_state
+        return self._get_impl(forward_mode).dispatch_b(*inner_state)
+
+    def combine(self, *args, **kwargs) -> Tuple:
+        self.combine_a(*args, **kwargs)
+        return self.combine_b()
+
+    def combine_a(
         self,
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
         forward_mode: ForwardMode,
-    ) -> torch.Tensor:
-        return self._get_dispatcher(forward_mode).combine(
+    ):
+        inner_state = self._get_impl(forward_mode).combine_a(
             hidden_states=hidden_states,
             topk_idx=topk_idx,
             topk_weights=topk_weights,
         )
+        self._combine_intermediate_state = forward_mode, inner_state
 
-    def _get_dispatcher(self, forward_mode: ForwardMode) -> _DeepEPDispatcherImplBase:
+    def combine_b(self):
+        forward_mode, inner_state = self._combine_intermediate_state
+        del self._combine_intermediate_state
+        return self._get_impl(forward_mode).combine_b(*inner_state)
+
+    def _get_impl(self, forward_mode: ForwardMode) -> "_DeepEPDispatcherImplBase":
         resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
         if resolved_deepep_mode == DeepEPMode.normal:
             return self._normal_dispatcher
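
Note: below is a minimal caller-side sketch of how the split stages could be used to overlap the DeepEP all-to-all with other work. It is a hypothetical illustration, not code from this patch: the wrapper function, run_experts, overlap_work, and the ForwardMode import path are assumptions; only the dispatch_a/dispatch_b/combine_a/combine_b methods (and the dispatch/combine wrappers that call both stages back to back) come from the diff above.

    import torch

    from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
    from sglang.srt.model_executor.forward_batch_info import ForwardMode  # assumed import path

    def moe_layer_forward(
        dispatcher: DeepEPDispatcher,
        hidden_states: torch.Tensor,
        topk_idx: torch.Tensor,
        topk_weights: torch.Tensor,
        num_experts: int,
        forward_mode: ForwardMode,
        run_experts,                 # placeholder callable for the expert MLPs
        overlap_work=lambda: None,   # placeholder for work overlapped with communication
    ):
        # Stage A launches the dispatch all-to-all; with async_finish / return_recv_hook
        # enabled it returns before the transfer has completed.
        dispatcher.dispatch_a(
            hidden_states=hidden_states,
            topk_idx=topk_idx,
            topk_weights=topk_weights,
            num_experts=num_experts,
            forward_mode=forward_mode,
        )
        overlap_work()  # compute that does not depend on the dispatched tokens

        # Stage B waits on the communication and returns the permuted tokens
        # (plus mode-specific extras such as seg_indptr or masked_m).
        hidden_states, topk_idx, topk_weights, *extra = dispatcher.dispatch_b()

        expert_output = run_experts(hidden_states, *extra)

        # The combine step is split the same way.
        dispatcher.combine_a(expert_output, topk_idx, topk_weights, forward_mode=forward_mode)
        overlap_work()
        return dispatcher.combine_b()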