Support async in DeepEP (#4610)
Co-authored-by: Cheng Wan <cwan39@gatech.edu>
This commit is contained in:
@@ -5,7 +5,6 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
use_deepep = False
|
use_deepep = False
|
||||||
|
|
||||||
import os
|
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -101,6 +100,7 @@ class DeepEPDispatcher:
|
|||||||
num_local_experts: int = None,
|
num_local_experts: int = None,
|
||||||
hidden_size: int = None,
|
hidden_size: int = None,
|
||||||
params_dtype: torch.dtype = None,
|
params_dtype: torch.dtype = None,
|
||||||
|
async_finish: bool = False,
|
||||||
):
|
):
|
||||||
self.group = group
|
self.group = group
|
||||||
self.router_topk = router_topk
|
self.router_topk = router_topk
|
||||||
@@ -117,6 +117,7 @@ class DeepEPDispatcher:
|
|||||||
self.token_probs = None
|
self.token_probs = None
|
||||||
# Handle used for combine operation
|
# Handle used for combine operation
|
||||||
self.handle = None
|
self.handle = None
|
||||||
|
self.async_finish = async_finish
|
||||||
|
|
||||||
# `num_max_dispatch_tokens_per_rank` (the actual batch size in the decoding engine) should be less than 256
|
# `num_max_dispatch_tokens_per_rank` (the actual batch size in the decoding engine) should be less than 256
|
||||||
# https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
|
# https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
|
||||||
@@ -182,7 +183,6 @@ class DeepEPDispatcher:
|
|||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
num_experts: int,
|
num_experts: int,
|
||||||
forward_mode: ForwardMode,
|
forward_mode: ForwardMode,
|
||||||
previous_event=None,
|
|
||||||
num_max_dispatch_tokens_per_rank: int = 128,
|
num_max_dispatch_tokens_per_rank: int = 128,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
topk_idx = topk_idx.to(torch.int64)
|
topk_idx = topk_idx.to(torch.int64)
|
||||||
@@ -195,9 +195,7 @@ class DeepEPDispatcher:
|
|||||||
num_recv_tokens_per_expert_list,
|
num_recv_tokens_per_expert_list,
|
||||||
handle,
|
handle,
|
||||||
event,
|
event,
|
||||||
) = self.dispatch_normal(
|
) = self.dispatch_normal(hidden_states, topk_idx, topk_weights, num_experts)
|
||||||
hidden_states, topk_idx, topk_weights, num_experts, previous_event
|
|
||||||
)
|
|
||||||
self.tokens_per_expert = torch.tensor(
|
self.tokens_per_expert = torch.tensor(
|
||||||
num_recv_tokens_per_expert_list,
|
num_recv_tokens_per_expert_list,
|
||||||
device=hidden_states.device,
|
device=hidden_states.device,
|
||||||
@@ -213,6 +211,10 @@ class DeepEPDispatcher:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.recv_expert_count = recv_expert_count
|
self.recv_expert_count = recv_expert_count
|
||||||
|
|
||||||
|
if self.async_finish:
|
||||||
|
event.current_stream_wait()
|
||||||
|
|
||||||
self.handle = handle
|
self.handle = handle
|
||||||
self.topk_idx = topk_idx
|
self.topk_idx = topk_idx
|
||||||
self.topk_weights = topk_weights
|
self.topk_weights = topk_weights
|
||||||
@@ -235,8 +237,9 @@ class DeepEPDispatcher:
|
|||||||
topk_idx: torch.Tensor,
|
topk_idx: torch.Tensor,
|
||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
num_experts: int,
|
num_experts: int,
|
||||||
previous_event=None,
|
|
||||||
):
|
):
|
||||||
|
previous_event = Buffer.capture() if self.async_finish else None
|
||||||
|
|
||||||
(
|
(
|
||||||
num_tokens_per_rank,
|
num_tokens_per_rank,
|
||||||
num_tokens_per_rdma_rank,
|
num_tokens_per_rdma_rank,
|
||||||
@@ -247,8 +250,8 @@ class DeepEPDispatcher:
|
|||||||
topk_idx,
|
topk_idx,
|
||||||
num_experts,
|
num_experts,
|
||||||
previous_event=previous_event,
|
previous_event=previous_event,
|
||||||
async_finish=False,
|
async_finish=self.async_finish,
|
||||||
allocate_on_comm_stream=False,
|
allocate_on_comm_stream=previous_event is not None,
|
||||||
)
|
)
|
||||||
|
|
||||||
(
|
(
|
||||||
@@ -267,8 +270,8 @@ class DeepEPDispatcher:
|
|||||||
is_token_in_rank=is_token_in_rank,
|
is_token_in_rank=is_token_in_rank,
|
||||||
num_tokens_per_expert=num_tokens_per_expert,
|
num_tokens_per_expert=num_tokens_per_expert,
|
||||||
previous_event=previous_event,
|
previous_event=previous_event,
|
||||||
async_finish=False,
|
async_finish=self.async_finish,
|
||||||
allocate_on_comm_stream=False,
|
allocate_on_comm_stream=(previous_event is not None) and self.async_finish,
|
||||||
)
|
)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
@@ -333,7 +336,7 @@ class DeepEPDispatcher:
|
|||||||
topk_idx,
|
topk_idx,
|
||||||
num_max_dispatch_tokens_per_rank,
|
num_max_dispatch_tokens_per_rank,
|
||||||
num_experts,
|
num_experts,
|
||||||
async_finish=False,
|
async_finish=self.async_finish,
|
||||||
return_recv_hook=False, # True for double-batch overlapping, need call hook()
|
return_recv_hook=False, # True for double-batch overlapping, need call hook()
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -373,16 +376,22 @@ class DeepEPDispatcher:
|
|||||||
hidden_states, event, hook = self.combine_low_latency(
|
hidden_states, event, hook = self.combine_low_latency(
|
||||||
hidden_states, self.topk_idx, self.topk_weights, self.handle
|
hidden_states, self.topk_idx, self.topk_weights, self.handle
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.async_finish:
|
||||||
|
event.current_stream_wait()
|
||||||
|
|
||||||
self.handle = None
|
self.handle = None
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
def combine_normal(self, x: torch.Tensor, handle: Tuple, previous_event=None):
|
def combine_normal(self, x: torch.Tensor, handle: Tuple):
|
||||||
|
previous_event = Buffer.capture() if self.async_finish else None
|
||||||
|
|
||||||
combined_x, _, event = self.buffer_normal.combine(
|
combined_x, _, event = self.buffer_normal.combine(
|
||||||
x,
|
x,
|
||||||
handle,
|
handle,
|
||||||
async_finish=False,
|
async_finish=self.async_finish,
|
||||||
previous_event=previous_event,
|
previous_event=previous_event,
|
||||||
allocate_on_comm_stream=False,
|
allocate_on_comm_stream=previous_event is not None,
|
||||||
)
|
)
|
||||||
return combined_x, event
|
return combined_x, event
|
||||||
|
|
||||||
@@ -399,7 +408,7 @@ class DeepEPDispatcher:
|
|||||||
topk_idx,
|
topk_idx,
|
||||||
topk_weights,
|
topk_weights,
|
||||||
handle,
|
handle,
|
||||||
async_finish=False,
|
async_finish=self.async_finish,
|
||||||
return_recv_hook=False, # True for double-batch overlapping, need call hook()
|
return_recv_hook=False, # True for double-batch overlapping, need call hook()
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -239,6 +239,7 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
num_local_experts=config.n_routed_experts // self.tp_size,
|
num_local_experts=config.n_routed_experts // self.tp_size,
|
||||||
hidden_size=config.hidden_size,
|
hidden_size=config.hidden_size,
|
||||||
params_dtype=config.torch_dtype,
|
params_dtype=config.torch_dtype,
|
||||||
|
async_finish=True, # TODO
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
|
|||||||
Reference in New Issue
Block a user