Small refactor DeepEPDispatcher into subclasses (#4994)
This commit is contained in:
@@ -23,7 +23,7 @@ _buffer_normal = None
|
|||||||
_buffer_low_latency = None
|
_buffer_low_latency = None
|
||||||
|
|
||||||
|
|
||||||
def get_buffer_normal(group: dist.ProcessGroup, hidden_bytes: int):
|
def _get_buffer_normal(group: dist.ProcessGroup, hidden_bytes: int):
|
||||||
"""
|
"""
|
||||||
Copy from DeepEP example usage in model inference prefilling.
|
Copy from DeepEP example usage in model inference prefilling.
|
||||||
https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-model-training-or-inference-prefilling
|
https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-model-training-or-inference-prefilling
|
||||||
@@ -53,7 +53,7 @@ def get_buffer_normal(group: dist.ProcessGroup, hidden_bytes: int):
|
|||||||
return _buffer_normal
|
return _buffer_normal
|
||||||
|
|
||||||
|
|
||||||
def get_buffer_low_latency(
|
def _get_buffer_low_latency(
|
||||||
group: dist.ProcessGroup,
|
group: dist.ProcessGroup,
|
||||||
num_max_dispatch_tokens_per_rank: int,
|
num_max_dispatch_tokens_per_rank: int,
|
||||||
hidden: int,
|
hidden: int,
|
||||||
@@ -85,24 +85,16 @@ def get_buffer_low_latency(
|
|||||||
return _buffer_low_latency
|
return _buffer_low_latency
|
||||||
|
|
||||||
|
|
||||||
class DeepEPDispatcher:
|
class _DeepEPDispatcherImplBase:
|
||||||
"""
|
|
||||||
Copy from Megatron-Core token_dispatcher MoEFlexTokenDispatcher
|
|
||||||
https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/moe/token_dispatcher.py
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
group: torch.distributed.ProcessGroup,
|
group: torch.distributed.ProcessGroup,
|
||||||
router_topk: int,
|
router_topk: int,
|
||||||
permute_fusion: bool = False,
|
permute_fusion: bool,
|
||||||
num_experts: int = None,
|
num_experts: int,
|
||||||
num_local_experts: int = None,
|
num_local_experts: int,
|
||||||
hidden_size: int = None,
|
hidden_size: int,
|
||||||
params_dtype: torch.dtype = None,
|
params_dtype: torch.dtype,
|
||||||
deepep_mode: DeepEPMode = DeepEPMode.auto,
|
|
||||||
async_finish: bool = False,
|
|
||||||
return_recv_hook: bool = False,
|
|
||||||
):
|
):
|
||||||
if not use_deepep:
|
if not use_deepep:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
@@ -119,115 +111,71 @@ class DeepEPDispatcher:
|
|||||||
self.params_dtype = params_dtype
|
self.params_dtype = params_dtype
|
||||||
self.params_bytes = 2
|
self.params_bytes = 2
|
||||||
|
|
||||||
self.deepep_mode = deepep_mode
|
|
||||||
self.handle = None
|
self.handle = None
|
||||||
|
|
||||||
if self.deepep_mode.enable_normal():
|
|
||||||
self.buffer_normal = get_buffer_normal(
|
|
||||||
self.group, self.hidden_size * self.params_bytes
|
|
||||||
)
|
|
||||||
self.async_finish = async_finish
|
|
||||||
self.src2dst = None
|
|
||||||
if self.deepep_mode.enable_low_latency():
|
|
||||||
"""
|
|
||||||
num_max_dispatch_tokens_per_rank: the actual batch size in the decoding engine should be less than 256
|
|
||||||
https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
|
|
||||||
"""
|
|
||||||
# TODO(ch-wan): allow users to set this value
|
|
||||||
self.num_max_dispatch_tokens_per_rank = 128
|
|
||||||
self.buffer_low_latency = get_buffer_low_latency(
|
|
||||||
self.group,
|
|
||||||
self.num_max_dispatch_tokens_per_rank,
|
|
||||||
self.hidden_size,
|
|
||||||
self.num_experts,
|
|
||||||
)
|
|
||||||
self.return_recv_hook = return_recv_hook
|
|
||||||
|
|
||||||
def deepep_permute(
|
|
||||||
self,
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
topk_idx: torch.Tensor,
|
|
||||||
fp8_dtype: Optional[torch.dtype] = None,
|
|
||||||
use_fp8_w8a8: bool = False,
|
|
||||||
use_block_quant: bool = False,
|
|
||||||
):
|
|
||||||
reorder_topk_ids, self.src2dst, seg_indptr = deepep_run_moe_deep_preprocess(
|
|
||||||
topk_idx, self.num_experts
|
|
||||||
)
|
|
||||||
num_total_tokens = reorder_topk_ids.numel()
|
|
||||||
gateup_input = torch.empty(
|
|
||||||
(int(num_total_tokens), hidden_states.shape[1]),
|
|
||||||
device=hidden_states.device,
|
|
||||||
dtype=(
|
|
||||||
fp8_dtype
|
|
||||||
if (use_fp8_w8a8 and not use_block_quant)
|
|
||||||
else hidden_states.dtype
|
|
||||||
),
|
|
||||||
)
|
|
||||||
# PreReorder
|
|
||||||
deepep_permute_triton_kernel[(hidden_states.shape[0],)](
|
|
||||||
hidden_states,
|
|
||||||
gateup_input,
|
|
||||||
self.src2dst,
|
|
||||||
topk_idx,
|
|
||||||
None,
|
|
||||||
self.router_topk,
|
|
||||||
hidden_states.shape[1],
|
|
||||||
BLOCK_SIZE=512,
|
|
||||||
)
|
|
||||||
return reorder_topk_ids, seg_indptr, gateup_input
|
|
||||||
|
|
||||||
def dispatch(
|
def dispatch(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
topk_idx: torch.Tensor,
|
topk_idx: torch.Tensor,
|
||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
num_experts: int,
|
num_experts: int,
|
||||||
num_max_dispatch_tokens_per_rank: int = 128,
|
num_max_dispatch_tokens_per_rank: int,
|
||||||
forward_mode: ForwardMode = None,
|
):
|
||||||
) -> Tuple:
|
raise NotImplementedError
|
||||||
topk_idx = topk_idx.to(torch.int64)
|
|
||||||
reorder_topk_ids = torch.empty(
|
|
||||||
(0,), device=hidden_states.device, dtype=torch.int64
|
|
||||||
)
|
|
||||||
seg_indptr = torch.zeros(
|
|
||||||
(num_experts + 1,), device=hidden_states.device, dtype=torch.int64
|
|
||||||
)
|
|
||||||
masked_m = torch.empty(
|
|
||||||
(self.num_local_experts,), device=hidden_states.device, dtype=torch.int64
|
|
||||||
)
|
|
||||||
expected_m = 0
|
|
||||||
|
|
||||||
resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
|
def combine(
|
||||||
if resolved_deepep_mode == DeepEPMode.normal:
|
self,
|
||||||
(
|
hidden_states: torch.Tensor,
|
||||||
hidden_states,
|
topk_idx: torch.Tensor,
|
||||||
topk_idx,
|
topk_weights: torch.Tensor,
|
||||||
topk_weights,
|
) -> torch.Tensor:
|
||||||
event,
|
raise NotImplementedError
|
||||||
) = self.dispatch_normal(hidden_states, topk_idx, topk_weights, num_experts)
|
|
||||||
event.current_stream_wait() if self.async_finish else ()
|
|
||||||
if hidden_states.shape[0] > 0:
|
class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
||||||
reorder_topk_ids, seg_indptr, hidden_states = self.deepep_permute(
|
def __init__(self, async_finish: bool, **kwargs):
|
||||||
hidden_states, topk_idx, fp8_dtype=hidden_states.dtype
|
super().__init__(**kwargs)
|
||||||
)
|
|
||||||
elif resolved_deepep_mode == DeepEPMode.low_latency:
|
self.buffer_normal = _get_buffer_normal(
|
||||||
expected_m = (
|
self.group, self.hidden_size * self.params_bytes
|
||||||
hidden_states.shape[0]
|
)
|
||||||
* self.buffer_low_latency.group_size
|
self.async_finish = async_finish
|
||||||
* topk_idx.shape[1]
|
self.src2dst = None
|
||||||
+ num_experts
|
|
||||||
) // num_experts
|
def dispatch(
|
||||||
hidden_states, masked_m, event, hook = self.dispatch_low_latency(
|
self,
|
||||||
hidden_states,
|
hidden_states: torch.Tensor,
|
||||||
topk_idx,
|
topk_idx: torch.Tensor,
|
||||||
num_max_dispatch_tokens_per_rank,
|
topk_weights: torch.Tensor,
|
||||||
num_experts,
|
num_experts: int,
|
||||||
use_fp8=True,
|
num_max_dispatch_tokens_per_rank: int,
|
||||||
|
):
|
||||||
|
topk_idx = topk_idx.to(torch.int64)
|
||||||
|
(
|
||||||
|
hidden_states,
|
||||||
|
topk_idx,
|
||||||
|
topk_weights,
|
||||||
|
event,
|
||||||
|
) = self._dispatch_normal(hidden_states, topk_idx, topk_weights, num_experts)
|
||||||
|
event.current_stream_wait() if self.async_finish else ()
|
||||||
|
if hidden_states.shape[0] > 0:
|
||||||
|
reorder_topk_ids, seg_indptr, hidden_states = self._deepep_permute(
|
||||||
|
hidden_states, topk_idx, fp8_dtype=hidden_states.dtype
|
||||||
)
|
)
|
||||||
hook() if self.return_recv_hook else event.current_stream_wait()
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")
|
reorder_topk_ids = torch.empty(
|
||||||
|
(0,), device=hidden_states.device, dtype=torch.int64
|
||||||
|
)
|
||||||
|
seg_indptr = torch.zeros(
|
||||||
|
(num_experts + 1,), device=hidden_states.device, dtype=torch.int64
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO
|
||||||
|
# masked_m = torch.empty(
|
||||||
|
# (self.num_local_experts,), device=hidden_states.device, dtype=torch.int64
|
||||||
|
# )
|
||||||
|
# expected_m = 0
|
||||||
|
masked_m = expected_m = None
|
||||||
|
|
||||||
return (
|
return (
|
||||||
hidden_states,
|
hidden_states,
|
||||||
@@ -239,7 +187,7 @@ class DeepEPDispatcher:
|
|||||||
expected_m,
|
expected_m,
|
||||||
)
|
)
|
||||||
|
|
||||||
def dispatch_normal(
|
def _dispatch_normal(
|
||||||
self,
|
self,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
topk_idx: torch.Tensor,
|
topk_idx: torch.Tensor,
|
||||||
@@ -292,7 +240,156 @@ class DeepEPDispatcher:
|
|||||||
event,
|
event,
|
||||||
)
|
)
|
||||||
|
|
||||||
def dispatch_low_latency(
|
def _deepep_permute(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
topk_idx: torch.Tensor,
|
||||||
|
fp8_dtype: Optional[torch.dtype] = None,
|
||||||
|
use_fp8_w8a8: bool = False,
|
||||||
|
use_block_quant: bool = False,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Copy from Megatron-Core token_dispatcher MoEFlexTokenDispatcher
|
||||||
|
https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/moe/token_dispatcher.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
reorder_topk_ids, self.src2dst, seg_indptr = deepep_run_moe_deep_preprocess(
|
||||||
|
topk_idx, self.num_experts
|
||||||
|
)
|
||||||
|
num_total_tokens = reorder_topk_ids.numel()
|
||||||
|
gateup_input = torch.empty(
|
||||||
|
(int(num_total_tokens), hidden_states.shape[1]),
|
||||||
|
device=hidden_states.device,
|
||||||
|
dtype=(
|
||||||
|
fp8_dtype
|
||||||
|
if (use_fp8_w8a8 and not use_block_quant)
|
||||||
|
else hidden_states.dtype
|
||||||
|
),
|
||||||
|
)
|
||||||
|
# PreReorder
|
||||||
|
deepep_permute_triton_kernel[(hidden_states.shape[0],)](
|
||||||
|
hidden_states,
|
||||||
|
gateup_input,
|
||||||
|
self.src2dst,
|
||||||
|
topk_idx,
|
||||||
|
None,
|
||||||
|
self.router_topk,
|
||||||
|
hidden_states.shape[1],
|
||||||
|
BLOCK_SIZE=512,
|
||||||
|
)
|
||||||
|
return reorder_topk_ids, seg_indptr, gateup_input
|
||||||
|
|
||||||
|
def combine(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
topk_idx: torch.Tensor,
|
||||||
|
topk_weights: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
if hidden_states.shape[0] > 0:
|
||||||
|
num_tokens = self.src2dst.shape[0] // self.router_topk
|
||||||
|
output = torch.empty(
|
||||||
|
(num_tokens, hidden_states.shape[1]),
|
||||||
|
device=hidden_states.device,
|
||||||
|
dtype=hidden_states.dtype,
|
||||||
|
)
|
||||||
|
deepep_post_reorder_triton_kernel[(num_tokens,)](
|
||||||
|
hidden_states,
|
||||||
|
output,
|
||||||
|
self.src2dst,
|
||||||
|
topk_idx,
|
||||||
|
topk_weights,
|
||||||
|
self.router_topk,
|
||||||
|
hidden_states.shape[1],
|
||||||
|
BLOCK_SIZE=512,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
output = torch.zeros(
|
||||||
|
(0, hidden_states.shape[1]),
|
||||||
|
device=hidden_states.device,
|
||||||
|
dtype=hidden_states.dtype,
|
||||||
|
)
|
||||||
|
hidden_states, event = self._combine_normal(
|
||||||
|
output,
|
||||||
|
)
|
||||||
|
event.current_stream_wait() if self.async_finish else ()
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
def _combine_normal(self, x: torch.Tensor):
|
||||||
|
previous_event = Buffer.capture() if self.async_finish else None
|
||||||
|
|
||||||
|
combined_x, _, event = self.buffer_normal.combine(
|
||||||
|
x,
|
||||||
|
self.handle,
|
||||||
|
async_finish=self.async_finish,
|
||||||
|
previous_event=previous_event,
|
||||||
|
allocate_on_comm_stream=previous_event is not None,
|
||||||
|
)
|
||||||
|
return combined_x, event
|
||||||
|
|
||||||
|
|
||||||
|
class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
|
||||||
|
def __init__(self, return_recv_hook: bool, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
"""
|
||||||
|
num_max_dispatch_tokens_per_rank: the actual batch size in the decoding engine should be less than 256
|
||||||
|
https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
|
||||||
|
"""
|
||||||
|
# TODO(ch-wan): allow users to set this value
|
||||||
|
self.num_max_dispatch_tokens_per_rank = 128
|
||||||
|
self.buffer_low_latency = _get_buffer_low_latency(
|
||||||
|
self.group,
|
||||||
|
self.num_max_dispatch_tokens_per_rank,
|
||||||
|
self.hidden_size,
|
||||||
|
self.num_experts,
|
||||||
|
)
|
||||||
|
self.return_recv_hook = return_recv_hook
|
||||||
|
|
||||||
|
def dispatch(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
topk_idx: torch.Tensor,
|
||||||
|
topk_weights: torch.Tensor,
|
||||||
|
num_experts: int,
|
||||||
|
num_max_dispatch_tokens_per_rank: int,
|
||||||
|
):
|
||||||
|
topk_idx = topk_idx.to(torch.int64)
|
||||||
|
expected_m = (
|
||||||
|
hidden_states.shape[0]
|
||||||
|
* self.buffer_low_latency.group_size
|
||||||
|
* topk_idx.shape[1]
|
||||||
|
+ num_experts
|
||||||
|
) // num_experts
|
||||||
|
hidden_states, masked_m, event, hook = self._dispatch_low_latency(
|
||||||
|
hidden_states,
|
||||||
|
topk_idx,
|
||||||
|
num_max_dispatch_tokens_per_rank,
|
||||||
|
num_experts,
|
||||||
|
use_fp8=True,
|
||||||
|
)
|
||||||
|
hook() if self.return_recv_hook else event.current_stream_wait()
|
||||||
|
|
||||||
|
# TODO
|
||||||
|
# reorder_topk_ids = torch.empty(
|
||||||
|
# (0,), device=hidden_states.device, dtype=torch.int64
|
||||||
|
# )
|
||||||
|
# seg_indptr = torch.zeros(
|
||||||
|
# (num_experts + 1,), device=hidden_states.device, dtype=torch.int64
|
||||||
|
# )
|
||||||
|
reorder_topk_ids = seg_indptr = None
|
||||||
|
|
||||||
|
return (
|
||||||
|
hidden_states,
|
||||||
|
topk_idx,
|
||||||
|
topk_weights,
|
||||||
|
reorder_topk_ids,
|
||||||
|
seg_indptr,
|
||||||
|
masked_m,
|
||||||
|
expected_m,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _dispatch_low_latency(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
topk_idx: torch.Tensor,
|
topk_idx: torch.Tensor,
|
||||||
@@ -351,62 +448,17 @@ class DeepEPDispatcher:
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
topk_idx: torch.Tensor,
|
topk_idx: torch.Tensor,
|
||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
forward_mode: ForwardMode,
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
|
hidden_states, event, hook = self._combine_low_latency(
|
||||||
if resolved_deepep_mode == DeepEPMode.normal:
|
hidden_states,
|
||||||
if hidden_states.shape[0] > 0:
|
topk_idx,
|
||||||
num_tokens = self.src2dst.shape[0] // self.router_topk
|
topk_weights,
|
||||||
output = torch.empty(
|
)
|
||||||
(num_tokens, hidden_states.shape[1]),
|
hook() if self.return_recv_hook else event.current_stream_wait()
|
||||||
device=hidden_states.device,
|
|
||||||
dtype=hidden_states.dtype,
|
|
||||||
)
|
|
||||||
deepep_post_reorder_triton_kernel[(num_tokens,)](
|
|
||||||
hidden_states,
|
|
||||||
output,
|
|
||||||
self.src2dst,
|
|
||||||
topk_idx,
|
|
||||||
topk_weights,
|
|
||||||
self.router_topk,
|
|
||||||
hidden_states.shape[1],
|
|
||||||
BLOCK_SIZE=512,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
output = torch.zeros(
|
|
||||||
(0, hidden_states.shape[1]),
|
|
||||||
device=hidden_states.device,
|
|
||||||
dtype=hidden_states.dtype,
|
|
||||||
)
|
|
||||||
hidden_states, event = self.combine_normal(
|
|
||||||
output,
|
|
||||||
)
|
|
||||||
event.current_stream_wait() if self.async_finish else ()
|
|
||||||
elif resolved_deepep_mode == DeepEPMode.low_latency:
|
|
||||||
hidden_states, event, hook = self.combine_low_latency(
|
|
||||||
hidden_states,
|
|
||||||
topk_idx,
|
|
||||||
topk_weights,
|
|
||||||
)
|
|
||||||
hook() if self.return_recv_hook else event.current_stream_wait()
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")
|
|
||||||
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
def combine_normal(self, x: torch.Tensor):
|
def _combine_low_latency(
|
||||||
previous_event = Buffer.capture() if self.async_finish else None
|
|
||||||
|
|
||||||
combined_x, _, event = self.buffer_normal.combine(
|
|
||||||
x,
|
|
||||||
self.handle,
|
|
||||||
async_finish=self.async_finish,
|
|
||||||
previous_event=previous_event,
|
|
||||||
allocate_on_comm_stream=previous_event is not None,
|
|
||||||
)
|
|
||||||
return combined_x, event
|
|
||||||
|
|
||||||
def combine_low_latency(
|
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
topk_idx: torch.Tensor,
|
topk_idx: torch.Tensor,
|
||||||
@@ -423,3 +475,80 @@ class DeepEPDispatcher:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
return combined_hidden_states, event, hook
|
return combined_hidden_states, event, hook
|
||||||
|
|
||||||
|
|
||||||
|
class DeepEPDispatcher:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
group: torch.distributed.ProcessGroup,
|
||||||
|
router_topk: int,
|
||||||
|
permute_fusion: bool = False,
|
||||||
|
num_experts: int = None,
|
||||||
|
num_local_experts: int = None,
|
||||||
|
hidden_size: int = None,
|
||||||
|
params_dtype: torch.dtype = None,
|
||||||
|
deepep_mode: DeepEPMode = DeepEPMode.auto,
|
||||||
|
async_finish: bool = False,
|
||||||
|
return_recv_hook: bool = False,
|
||||||
|
):
|
||||||
|
self.deepep_mode = deepep_mode
|
||||||
|
|
||||||
|
common_kwargs = dict(
|
||||||
|
group=group,
|
||||||
|
router_topk=router_topk,
|
||||||
|
permute_fusion=permute_fusion,
|
||||||
|
num_experts=num_experts,
|
||||||
|
num_local_experts=num_local_experts,
|
||||||
|
hidden_size=hidden_size,
|
||||||
|
params_dtype=params_dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.deepep_mode.enable_normal():
|
||||||
|
self._normal_dispatcher = _DeepEPDispatcherImplNormal(
|
||||||
|
async_finish=async_finish,
|
||||||
|
**common_kwargs,
|
||||||
|
)
|
||||||
|
if self.deepep_mode.enable_low_latency():
|
||||||
|
self._low_latency_dispatcher = _DeepEPDispatcherImplLowLatency(
|
||||||
|
return_recv_hook=return_recv_hook,
|
||||||
|
**common_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def dispatch(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
topk_idx: torch.Tensor,
|
||||||
|
topk_weights: torch.Tensor,
|
||||||
|
num_experts: int,
|
||||||
|
num_max_dispatch_tokens_per_rank: int = 128,
|
||||||
|
forward_mode: ForwardMode = None,
|
||||||
|
) -> Tuple:
|
||||||
|
return self._get_dispatcher(forward_mode).dispatch(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
topk_idx=topk_idx,
|
||||||
|
topk_weights=topk_weights,
|
||||||
|
num_experts=num_experts,
|
||||||
|
num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank,
|
||||||
|
)
|
||||||
|
|
||||||
|
def combine(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
topk_idx: torch.Tensor,
|
||||||
|
topk_weights: torch.Tensor,
|
||||||
|
forward_mode: ForwardMode,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return self._get_dispatcher(forward_mode).combine(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
topk_idx=topk_idx,
|
||||||
|
topk_weights=topk_weights,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_dispatcher(self, forward_mode: ForwardMode) -> _DeepEPDispatcherImplBase:
|
||||||
|
resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
|
||||||
|
if resolved_deepep_mode == DeepEPMode.normal:
|
||||||
|
return self._normal_dispatcher
|
||||||
|
elif resolved_deepep_mode == DeepEPMode.low_latency:
|
||||||
|
return self._low_latency_dispatcher
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")
|
||||||
|
|||||||
@@ -188,35 +188,24 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
if global_server_args_dict["enable_deepep_moe"]
|
if global_server_args_dict["enable_deepep_moe"]
|
||||||
else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
|
else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
|
||||||
)
|
)
|
||||||
if not global_server_args_dict["enable_deepep_moe"]:
|
self.experts = MoEImpl(
|
||||||
self.experts = MoEImpl(
|
num_experts=config.n_routed_experts,
|
||||||
num_experts=config.n_routed_experts,
|
top_k=config.num_experts_per_tok,
|
||||||
top_k=config.num_experts_per_tok,
|
hidden_size=config.hidden_size,
|
||||||
hidden_size=config.hidden_size,
|
intermediate_size=config.moe_intermediate_size,
|
||||||
intermediate_size=config.moe_intermediate_size,
|
renormalize=config.norm_topk_prob,
|
||||||
renormalize=config.norm_topk_prob,
|
quant_config=quant_config,
|
||||||
quant_config=quant_config,
|
use_grouped_topk=True,
|
||||||
use_grouped_topk=True,
|
num_expert_group=config.n_group,
|
||||||
num_expert_group=config.n_group,
|
topk_group=config.topk_group,
|
||||||
topk_group=config.topk_group,
|
correction_bias=self.gate.e_score_correction_bias,
|
||||||
correction_bias=self.gate.e_score_correction_bias,
|
prefix=add_prefix("experts", prefix),
|
||||||
prefix=add_prefix("experts", prefix),
|
**(
|
||||||
)
|
dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]])
|
||||||
else:
|
if global_server_args_dict["enable_deepep_moe"]
|
||||||
self.experts = MoEImpl(
|
else {}
|
||||||
num_experts=config.n_routed_experts,
|
),
|
||||||
top_k=config.num_experts_per_tok,
|
)
|
||||||
hidden_size=config.hidden_size,
|
|
||||||
intermediate_size=config.moe_intermediate_size,
|
|
||||||
renormalize=config.norm_topk_prob,
|
|
||||||
quant_config=quant_config,
|
|
||||||
use_grouped_topk=True,
|
|
||||||
num_expert_group=config.n_group,
|
|
||||||
topk_group=config.topk_group,
|
|
||||||
correction_bias=self.gate.e_score_correction_bias,
|
|
||||||
prefix=add_prefix("experts", prefix),
|
|
||||||
deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]],
|
|
||||||
)
|
|
||||||
|
|
||||||
if config.n_shared_experts is not None:
|
if config.n_shared_experts is not None:
|
||||||
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
|
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
|
||||||
|
|||||||
Reference in New Issue
Block a user