Support async DeepEP by splitting into two stages (#4995)
This commit is contained in:
@@ -113,7 +113,7 @@ class _DeepEPDispatcherImplBase:
|
|||||||
|
|
||||||
self.handle = None
|
self.handle = None
|
||||||
|
|
||||||
def dispatch(
|
def dispatch_a(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
topk_idx: torch.Tensor,
|
topk_idx: torch.Tensor,
|
||||||
@@ -123,12 +123,18 @@ class _DeepEPDispatcherImplBase:
|
|||||||
):
|
):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def combine(
|
def dispatch_b(self, *args, **kwargs):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def combine_a(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
topk_idx: torch.Tensor,
|
topk_idx: torch.Tensor,
|
||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
) -> torch.Tensor:
|
):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def combine_b(self, *args, **kwargs):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
@@ -142,7 +148,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
|||||||
self.async_finish = async_finish
|
self.async_finish = async_finish
|
||||||
self.src2dst = None
|
self.src2dst = None
|
||||||
|
|
||||||
def dispatch(
|
def dispatch_a(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
topk_idx: torch.Tensor,
|
topk_idx: torch.Tensor,
|
||||||
@@ -151,12 +157,20 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
|||||||
num_max_dispatch_tokens_per_rank: int,
|
num_max_dispatch_tokens_per_rank: int,
|
||||||
):
|
):
|
||||||
topk_idx = topk_idx.to(torch.int64)
|
topk_idx = topk_idx.to(torch.int64)
|
||||||
|
previous_event = Buffer.capture() if self.async_finish else None
|
||||||
|
return hidden_states, topk_idx, topk_weights, num_experts, previous_event
|
||||||
|
|
||||||
|
def dispatch_b(
|
||||||
|
self, hidden_states, topk_idx, topk_weights, num_experts, previous_event
|
||||||
|
):
|
||||||
(
|
(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
topk_idx,
|
topk_idx,
|
||||||
topk_weights,
|
topk_weights,
|
||||||
event,
|
event,
|
||||||
) = self._dispatch_normal(hidden_states, topk_idx, topk_weights, num_experts)
|
) = self._dispatch_core(
|
||||||
|
hidden_states, topk_idx, topk_weights, num_experts, previous_event
|
||||||
|
)
|
||||||
event.current_stream_wait() if self.async_finish else ()
|
event.current_stream_wait() if self.async_finish else ()
|
||||||
if hidden_states.shape[0] > 0:
|
if hidden_states.shape[0] > 0:
|
||||||
reorder_topk_ids, seg_indptr, hidden_states = self._deepep_permute(
|
reorder_topk_ids, seg_indptr, hidden_states = self._deepep_permute(
|
||||||
@@ -187,15 +201,14 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
|||||||
expected_m,
|
expected_m,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _dispatch_normal(
|
def _dispatch_core(
|
||||||
self,
|
self,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
topk_idx: torch.Tensor,
|
topk_idx: torch.Tensor,
|
||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
num_experts: int,
|
num_experts: int,
|
||||||
|
previous_event,
|
||||||
):
|
):
|
||||||
previous_event = Buffer.capture() if self.async_finish else None
|
|
||||||
|
|
||||||
(
|
(
|
||||||
num_tokens_per_rank,
|
num_tokens_per_rank,
|
||||||
num_tokens_per_rdma_rank,
|
num_tokens_per_rdma_rank,
|
||||||
@@ -279,12 +292,12 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
|||||||
)
|
)
|
||||||
return reorder_topk_ids, seg_indptr, gateup_input
|
return reorder_topk_ids, seg_indptr, gateup_input
|
||||||
|
|
||||||
def combine(
|
def combine_a(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
topk_idx: torch.Tensor,
|
topk_idx: torch.Tensor,
|
||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
) -> torch.Tensor:
|
):
|
||||||
if hidden_states.shape[0] > 0:
|
if hidden_states.shape[0] > 0:
|
||||||
num_tokens = self.src2dst.shape[0] // self.router_topk
|
num_tokens = self.src2dst.shape[0] // self.router_topk
|
||||||
output = torch.empty(
|
output = torch.empty(
|
||||||
@@ -308,16 +321,15 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
|||||||
device=hidden_states.device,
|
device=hidden_states.device,
|
||||||
dtype=hidden_states.dtype,
|
dtype=hidden_states.dtype,
|
||||||
)
|
)
|
||||||
hidden_states, event = self._combine_normal(
|
previous_event = Buffer.capture() if self.async_finish else None
|
||||||
output,
|
return output, previous_event
|
||||||
)
|
|
||||||
event.current_stream_wait() if self.async_finish else ()
|
|
||||||
|
|
||||||
|
def combine_b(self, output, previous_event):
|
||||||
|
hidden_states, event = self._combine_core(output, previous_event)
|
||||||
|
event.current_stream_wait() if self.async_finish else ()
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
def _combine_normal(self, x: torch.Tensor):
|
def _combine_core(self, x: torch.Tensor, previous_event):
|
||||||
previous_event = Buffer.capture() if self.async_finish else None
|
|
||||||
|
|
||||||
combined_x, _, event = self.buffer_normal.combine(
|
combined_x, _, event = self.buffer_normal.combine(
|
||||||
x,
|
x,
|
||||||
self.handle,
|
self.handle,
|
||||||
@@ -346,7 +358,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
|
|||||||
)
|
)
|
||||||
self.return_recv_hook = return_recv_hook
|
self.return_recv_hook = return_recv_hook
|
||||||
|
|
||||||
def dispatch(
|
def dispatch_a(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
topk_idx: torch.Tensor,
|
topk_idx: torch.Tensor,
|
||||||
@@ -361,13 +373,33 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
|
|||||||
* topk_idx.shape[1]
|
* topk_idx.shape[1]
|
||||||
+ num_experts
|
+ num_experts
|
||||||
) // num_experts
|
) // num_experts
|
||||||
hidden_states, masked_m, event, hook = self._dispatch_low_latency(
|
hidden_states, masked_m, event, hook = self._dispatch_core(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
topk_idx,
|
topk_idx,
|
||||||
num_max_dispatch_tokens_per_rank,
|
num_max_dispatch_tokens_per_rank,
|
||||||
num_experts,
|
num_experts,
|
||||||
use_fp8=True,
|
use_fp8=True,
|
||||||
)
|
)
|
||||||
|
return (
|
||||||
|
hidden_states,
|
||||||
|
topk_idx,
|
||||||
|
topk_weights,
|
||||||
|
masked_m,
|
||||||
|
expected_m,
|
||||||
|
event,
|
||||||
|
hook,
|
||||||
|
)
|
||||||
|
|
||||||
|
def dispatch_b(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
topk_idx,
|
||||||
|
topk_weights,
|
||||||
|
masked_m,
|
||||||
|
expected_m,
|
||||||
|
event,
|
||||||
|
hook,
|
||||||
|
):
|
||||||
hook() if self.return_recv_hook else event.current_stream_wait()
|
hook() if self.return_recv_hook else event.current_stream_wait()
|
||||||
|
|
||||||
# TODO
|
# TODO
|
||||||
@@ -389,7 +421,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
|
|||||||
expected_m,
|
expected_m,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _dispatch_low_latency(
|
def _dispatch_core(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
topk_idx: torch.Tensor,
|
topk_idx: torch.Tensor,
|
||||||
@@ -443,22 +475,24 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
|
|||||||
)
|
)
|
||||||
return packed_recv_hidden, packed_recv_count, event, hook
|
return packed_recv_hidden, packed_recv_count, event, hook
|
||||||
|
|
||||||
def combine(
|
def combine_a(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
topk_idx: torch.Tensor,
|
topk_idx: torch.Tensor,
|
||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
) -> torch.Tensor:
|
):
|
||||||
hidden_states, event, hook = self._combine_low_latency(
|
hidden_states, event, hook = self._combine_core(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
topk_idx,
|
topk_idx,
|
||||||
topk_weights,
|
topk_weights,
|
||||||
)
|
)
|
||||||
hook() if self.return_recv_hook else event.current_stream_wait()
|
return hidden_states, event, hook
|
||||||
|
|
||||||
|
def combine_b(self, hidden_states, event, hook):
|
||||||
|
hook() if self.return_recv_hook else event.current_stream_wait()
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
def _combine_low_latency(
|
def _combine_core(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
topk_idx: torch.Tensor,
|
topk_idx: torch.Tensor,
|
||||||
@@ -514,7 +548,11 @@ class DeepEPDispatcher:
|
|||||||
**common_kwargs,
|
**common_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
def dispatch(
|
def dispatch(self, *args, **kwargs) -> Tuple:
|
||||||
|
self.dispatch_a(*args, **kwargs)
|
||||||
|
return self.dispatch_b()
|
||||||
|
|
||||||
|
def dispatch_a(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
topk_idx: torch.Tensor,
|
topk_idx: torch.Tensor,
|
||||||
@@ -522,29 +560,45 @@ class DeepEPDispatcher:
|
|||||||
num_experts: int,
|
num_experts: int,
|
||||||
num_max_dispatch_tokens_per_rank: int = 128,
|
num_max_dispatch_tokens_per_rank: int = 128,
|
||||||
forward_mode: ForwardMode = None,
|
forward_mode: ForwardMode = None,
|
||||||
) -> Tuple:
|
):
|
||||||
return self._get_dispatcher(forward_mode).dispatch(
|
inner_state = self._get_impl(forward_mode).dispatch_a(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
topk_idx=topk_idx,
|
topk_idx=topk_idx,
|
||||||
topk_weights=topk_weights,
|
topk_weights=topk_weights,
|
||||||
num_experts=num_experts,
|
num_experts=num_experts,
|
||||||
num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank,
|
num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank,
|
||||||
)
|
)
|
||||||
|
self._dispatch_intermediate_state = forward_mode, inner_state
|
||||||
|
|
||||||
def combine(
|
def dispatch_b(self):
|
||||||
|
forward_mode, inner_state = self._dispatch_intermediate_state
|
||||||
|
del self._dispatch_intermediate_state
|
||||||
|
return self._get_impl(forward_mode).dispatch_b(*inner_state)
|
||||||
|
|
||||||
|
def combine(self, *args, **kwargs) -> Tuple:
|
||||||
|
self.combine_a(*args, **kwargs)
|
||||||
|
return self.combine_b()
|
||||||
|
|
||||||
|
def combine_a(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
topk_idx: torch.Tensor,
|
topk_idx: torch.Tensor,
|
||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
forward_mode: ForwardMode,
|
forward_mode: ForwardMode,
|
||||||
) -> torch.Tensor:
|
):
|
||||||
return self._get_dispatcher(forward_mode).combine(
|
inner_state = self._get_impl(forward_mode).combine_a(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
topk_idx=topk_idx,
|
topk_idx=topk_idx,
|
||||||
topk_weights=topk_weights,
|
topk_weights=topk_weights,
|
||||||
)
|
)
|
||||||
|
self._combine_intermediate_state = forward_mode, inner_state
|
||||||
|
|
||||||
def _get_dispatcher(self, forward_mode: ForwardMode) -> _DeepEPDispatcherImplBase:
|
def combine_b(self):
|
||||||
|
forward_mode, inner_state = self._combine_intermediate_state
|
||||||
|
del self._combine_intermediate_state
|
||||||
|
return self._get_impl(forward_mode).combine_b(*inner_state)
|
||||||
|
|
||||||
|
def _get_impl(self, forward_mode: ForwardMode) -> "_DeepEPDispatcherImplBase":
|
||||||
resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
|
resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
|
||||||
if resolved_deepep_mode == DeepEPMode.normal:
|
if resolved_deepep_mode == DeepEPMode.normal:
|
||||||
return self._normal_dispatcher
|
return self._normal_dispatcher
|
||||||
|
|||||||
Reference in New Issue
Block a user