Support async in DeepEP (#4610)
Co-authored-by: Cheng Wan <cwan39@gatech.edu>
This commit is contained in:
@@ -5,7 +5,6 @@ try:
|
||||
except ImportError:
|
||||
use_deepep = False
|
||||
|
||||
import os
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
@@ -101,6 +100,7 @@ class DeepEPDispatcher:
|
||||
num_local_experts: int = None,
|
||||
hidden_size: int = None,
|
||||
params_dtype: torch.dtype = None,
|
||||
async_finish: bool = False,
|
||||
):
|
||||
self.group = group
|
||||
self.router_topk = router_topk
|
||||
@@ -117,6 +117,7 @@ class DeepEPDispatcher:
|
||||
self.token_probs = None
|
||||
# Handle used for combine operation
|
||||
self.handle = None
|
||||
self.async_finish = async_finish
|
||||
|
||||
# `num_max_dispatch_tokens_per_rank` (the actual batch size in the decoding engine) should be less than 256
|
||||
# https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
|
||||
@@ -182,7 +183,6 @@ class DeepEPDispatcher:
|
||||
topk_weights: torch.Tensor,
|
||||
num_experts: int,
|
||||
forward_mode: ForwardMode,
|
||||
previous_event=None,
|
||||
num_max_dispatch_tokens_per_rank: int = 128,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
topk_idx = topk_idx.to(torch.int64)
|
||||
@@ -195,9 +195,7 @@ class DeepEPDispatcher:
|
||||
num_recv_tokens_per_expert_list,
|
||||
handle,
|
||||
event,
|
||||
) = self.dispatch_normal(
|
||||
hidden_states, topk_idx, topk_weights, num_experts, previous_event
|
||||
)
|
||||
) = self.dispatch_normal(hidden_states, topk_idx, topk_weights, num_experts)
|
||||
self.tokens_per_expert = torch.tensor(
|
||||
num_recv_tokens_per_expert_list,
|
||||
device=hidden_states.device,
|
||||
@@ -213,6 +211,10 @@ class DeepEPDispatcher:
|
||||
)
|
||||
)
|
||||
self.recv_expert_count = recv_expert_count
|
||||
|
||||
if self.async_finish:
|
||||
event.current_stream_wait()
|
||||
|
||||
self.handle = handle
|
||||
self.topk_idx = topk_idx
|
||||
self.topk_weights = topk_weights
|
||||
@@ -235,8 +237,9 @@ class DeepEPDispatcher:
|
||||
topk_idx: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
num_experts: int,
|
||||
previous_event=None,
|
||||
):
|
||||
previous_event = Buffer.capture() if self.async_finish else None
|
||||
|
||||
(
|
||||
num_tokens_per_rank,
|
||||
num_tokens_per_rdma_rank,
|
||||
@@ -247,8 +250,8 @@ class DeepEPDispatcher:
|
||||
topk_idx,
|
||||
num_experts,
|
||||
previous_event=previous_event,
|
||||
async_finish=False,
|
||||
allocate_on_comm_stream=False,
|
||||
async_finish=self.async_finish,
|
||||
allocate_on_comm_stream=previous_event is not None,
|
||||
)
|
||||
|
||||
(
|
||||
@@ -267,8 +270,8 @@ class DeepEPDispatcher:
|
||||
is_token_in_rank=is_token_in_rank,
|
||||
num_tokens_per_expert=num_tokens_per_expert,
|
||||
previous_event=previous_event,
|
||||
async_finish=False,
|
||||
allocate_on_comm_stream=False,
|
||||
async_finish=self.async_finish,
|
||||
allocate_on_comm_stream=(previous_event is not None) and self.async_finish,
|
||||
)
|
||||
|
||||
return (
|
||||
@@ -333,7 +336,7 @@ class DeepEPDispatcher:
|
||||
topk_idx,
|
||||
num_max_dispatch_tokens_per_rank,
|
||||
num_experts,
|
||||
async_finish=False,
|
||||
async_finish=self.async_finish,
|
||||
return_recv_hook=False, # True for double-batch overlapping, need call hook()
|
||||
)
|
||||
)
|
||||
@@ -373,16 +376,22 @@ class DeepEPDispatcher:
|
||||
hidden_states, event, hook = self.combine_low_latency(
|
||||
hidden_states, self.topk_idx, self.topk_weights, self.handle
|
||||
)
|
||||
|
||||
if self.async_finish:
|
||||
event.current_stream_wait()
|
||||
|
||||
self.handle = None
|
||||
return hidden_states
|
||||
|
||||
def combine_normal(self, x: torch.Tensor, handle: Tuple, previous_event=None):
|
||||
def combine_normal(self, x: torch.Tensor, handle: Tuple):
|
||||
previous_event = Buffer.capture() if self.async_finish else None
|
||||
|
||||
combined_x, _, event = self.buffer_normal.combine(
|
||||
x,
|
||||
handle,
|
||||
async_finish=False,
|
||||
async_finish=self.async_finish,
|
||||
previous_event=previous_event,
|
||||
allocate_on_comm_stream=False,
|
||||
allocate_on_comm_stream=previous_event is not None,
|
||||
)
|
||||
return combined_x, event
|
||||
|
||||
@@ -399,7 +408,7 @@ class DeepEPDispatcher:
|
||||
topk_idx,
|
||||
topk_weights,
|
||||
handle,
|
||||
async_finish=False,
|
||||
async_finish=self.async_finish,
|
||||
return_recv_hook=False, # True for double-batch overlapping, need call hook()
|
||||
)
|
||||
)
|
||||
|
||||
@@ -239,6 +239,7 @@ class DeepseekV2MoE(nn.Module):
|
||||
num_local_experts=config.n_routed_experts // self.tp_size,
|
||||
hidden_size=config.hidden_size,
|
||||
params_dtype=config.torch_dtype,
|
||||
async_finish=True, # TODO
|
||||
)
|
||||
|
||||
def forward(
|
||||
|
||||
Reference in New Issue
Block a user