Small refactor DeepEPDispatcher into subclasses (#4994)
This commit is contained in:
@@ -23,7 +23,7 @@ _buffer_normal = None
|
||||
_buffer_low_latency = None
|
||||
|
||||
|
||||
def get_buffer_normal(group: dist.ProcessGroup, hidden_bytes: int):
|
||||
def _get_buffer_normal(group: dist.ProcessGroup, hidden_bytes: int):
|
||||
"""
|
||||
Copy from DeepEP example usage in model inference prefilling.
|
||||
https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-model-training-or-inference-prefilling
|
||||
@@ -53,7 +53,7 @@ def get_buffer_normal(group: dist.ProcessGroup, hidden_bytes: int):
|
||||
return _buffer_normal
|
||||
|
||||
|
||||
def get_buffer_low_latency(
|
||||
def _get_buffer_low_latency(
|
||||
group: dist.ProcessGroup,
|
||||
num_max_dispatch_tokens_per_rank: int,
|
||||
hidden: int,
|
||||
@@ -85,24 +85,16 @@ def get_buffer_low_latency(
|
||||
return _buffer_low_latency
|
||||
|
||||
|
||||
class DeepEPDispatcher:
|
||||
"""
|
||||
Copy from Megatron-Core token_dispatcher MoEFlexTokenDispatcher
|
||||
https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/moe/token_dispatcher.py
|
||||
"""
|
||||
|
||||
class _DeepEPDispatcherImplBase:
|
||||
def __init__(
|
||||
self,
|
||||
group: torch.distributed.ProcessGroup,
|
||||
router_topk: int,
|
||||
permute_fusion: bool = False,
|
||||
num_experts: int = None,
|
||||
num_local_experts: int = None,
|
||||
hidden_size: int = None,
|
||||
params_dtype: torch.dtype = None,
|
||||
deepep_mode: DeepEPMode = DeepEPMode.auto,
|
||||
async_finish: bool = False,
|
||||
return_recv_hook: bool = False,
|
||||
permute_fusion: bool,
|
||||
num_experts: int,
|
||||
num_local_experts: int,
|
||||
hidden_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
):
|
||||
if not use_deepep:
|
||||
raise ImportError(
|
||||
@@ -119,115 +111,71 @@ class DeepEPDispatcher:
|
||||
self.params_dtype = params_dtype
|
||||
self.params_bytes = 2
|
||||
|
||||
self.deepep_mode = deepep_mode
|
||||
self.handle = None
|
||||
|
||||
if self.deepep_mode.enable_normal():
|
||||
self.buffer_normal = get_buffer_normal(
|
||||
self.group, self.hidden_size * self.params_bytes
|
||||
)
|
||||
self.async_finish = async_finish
|
||||
self.src2dst = None
|
||||
if self.deepep_mode.enable_low_latency():
|
||||
"""
|
||||
num_max_dispatch_tokens_per_rank: the actual batch size in the decoding engine should be less than 256
|
||||
https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
|
||||
"""
|
||||
# TODO(ch-wan): allow users to set this value
|
||||
self.num_max_dispatch_tokens_per_rank = 128
|
||||
self.buffer_low_latency = get_buffer_low_latency(
|
||||
self.group,
|
||||
self.num_max_dispatch_tokens_per_rank,
|
||||
self.hidden_size,
|
||||
self.num_experts,
|
||||
)
|
||||
self.return_recv_hook = return_recv_hook
|
||||
|
||||
def deepep_permute(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_idx: torch.Tensor,
|
||||
fp8_dtype: Optional[torch.dtype] = None,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_block_quant: bool = False,
|
||||
):
|
||||
reorder_topk_ids, self.src2dst, seg_indptr = deepep_run_moe_deep_preprocess(
|
||||
topk_idx, self.num_experts
|
||||
)
|
||||
num_total_tokens = reorder_topk_ids.numel()
|
||||
gateup_input = torch.empty(
|
||||
(int(num_total_tokens), hidden_states.shape[1]),
|
||||
device=hidden_states.device,
|
||||
dtype=(
|
||||
fp8_dtype
|
||||
if (use_fp8_w8a8 and not use_block_quant)
|
||||
else hidden_states.dtype
|
||||
),
|
||||
)
|
||||
# PreReorder
|
||||
deepep_permute_triton_kernel[(hidden_states.shape[0],)](
|
||||
hidden_states,
|
||||
gateup_input,
|
||||
self.src2dst,
|
||||
topk_idx,
|
||||
None,
|
||||
self.router_topk,
|
||||
hidden_states.shape[1],
|
||||
BLOCK_SIZE=512,
|
||||
)
|
||||
return reorder_topk_ids, seg_indptr, gateup_input
|
||||
|
||||
def dispatch(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_idx: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
num_experts: int,
|
||||
num_max_dispatch_tokens_per_rank: int = 128,
|
||||
forward_mode: ForwardMode = None,
|
||||
) -> Tuple:
|
||||
topk_idx = topk_idx.to(torch.int64)
|
||||
reorder_topk_ids = torch.empty(
|
||||
(0,), device=hidden_states.device, dtype=torch.int64
|
||||
)
|
||||
seg_indptr = torch.zeros(
|
||||
(num_experts + 1,), device=hidden_states.device, dtype=torch.int64
|
||||
)
|
||||
masked_m = torch.empty(
|
||||
(self.num_local_experts,), device=hidden_states.device, dtype=torch.int64
|
||||
)
|
||||
expected_m = 0
|
||||
num_max_dispatch_tokens_per_rank: int,
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
|
||||
if resolved_deepep_mode == DeepEPMode.normal:
|
||||
(
|
||||
hidden_states,
|
||||
topk_idx,
|
||||
topk_weights,
|
||||
event,
|
||||
) = self.dispatch_normal(hidden_states, topk_idx, topk_weights, num_experts)
|
||||
event.current_stream_wait() if self.async_finish else ()
|
||||
if hidden_states.shape[0] > 0:
|
||||
reorder_topk_ids, seg_indptr, hidden_states = self.deepep_permute(
|
||||
hidden_states, topk_idx, fp8_dtype=hidden_states.dtype
|
||||
)
|
||||
elif resolved_deepep_mode == DeepEPMode.low_latency:
|
||||
expected_m = (
|
||||
hidden_states.shape[0]
|
||||
* self.buffer_low_latency.group_size
|
||||
* topk_idx.shape[1]
|
||||
+ num_experts
|
||||
) // num_experts
|
||||
hidden_states, masked_m, event, hook = self.dispatch_low_latency(
|
||||
hidden_states,
|
||||
topk_idx,
|
||||
num_max_dispatch_tokens_per_rank,
|
||||
num_experts,
|
||||
use_fp8=True,
|
||||
def combine(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_idx: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
||||
def __init__(self, async_finish: bool, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.buffer_normal = _get_buffer_normal(
|
||||
self.group, self.hidden_size * self.params_bytes
|
||||
)
|
||||
self.async_finish = async_finish
|
||||
self.src2dst = None
|
||||
|
||||
def dispatch(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_idx: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
num_experts: int,
|
||||
num_max_dispatch_tokens_per_rank: int,
|
||||
):
|
||||
topk_idx = topk_idx.to(torch.int64)
|
||||
(
|
||||
hidden_states,
|
||||
topk_idx,
|
||||
topk_weights,
|
||||
event,
|
||||
) = self._dispatch_normal(hidden_states, topk_idx, topk_weights, num_experts)
|
||||
event.current_stream_wait() if self.async_finish else ()
|
||||
if hidden_states.shape[0] > 0:
|
||||
reorder_topk_ids, seg_indptr, hidden_states = self._deepep_permute(
|
||||
hidden_states, topk_idx, fp8_dtype=hidden_states.dtype
|
||||
)
|
||||
hook() if self.return_recv_hook else event.current_stream_wait()
|
||||
else:
|
||||
raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")
|
||||
reorder_topk_ids = torch.empty(
|
||||
(0,), device=hidden_states.device, dtype=torch.int64
|
||||
)
|
||||
seg_indptr = torch.zeros(
|
||||
(num_experts + 1,), device=hidden_states.device, dtype=torch.int64
|
||||
)
|
||||
|
||||
# TODO
|
||||
# masked_m = torch.empty(
|
||||
# (self.num_local_experts,), device=hidden_states.device, dtype=torch.int64
|
||||
# )
|
||||
# expected_m = 0
|
||||
masked_m = expected_m = None
|
||||
|
||||
return (
|
||||
hidden_states,
|
||||
@@ -239,7 +187,7 @@ class DeepEPDispatcher:
|
||||
expected_m,
|
||||
)
|
||||
|
||||
def dispatch_normal(
|
||||
def _dispatch_normal(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
topk_idx: torch.Tensor,
|
||||
@@ -292,7 +240,156 @@ class DeepEPDispatcher:
|
||||
event,
|
||||
)
|
||||
|
||||
def dispatch_low_latency(
|
||||
def _deepep_permute(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_idx: torch.Tensor,
|
||||
fp8_dtype: Optional[torch.dtype] = None,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_block_quant: bool = False,
|
||||
):
|
||||
"""
|
||||
Copy from Megatron-Core token_dispatcher MoEFlexTokenDispatcher
|
||||
https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/moe/token_dispatcher.py
|
||||
"""
|
||||
|
||||
reorder_topk_ids, self.src2dst, seg_indptr = deepep_run_moe_deep_preprocess(
|
||||
topk_idx, self.num_experts
|
||||
)
|
||||
num_total_tokens = reorder_topk_ids.numel()
|
||||
gateup_input = torch.empty(
|
||||
(int(num_total_tokens), hidden_states.shape[1]),
|
||||
device=hidden_states.device,
|
||||
dtype=(
|
||||
fp8_dtype
|
||||
if (use_fp8_w8a8 and not use_block_quant)
|
||||
else hidden_states.dtype
|
||||
),
|
||||
)
|
||||
# PreReorder
|
||||
deepep_permute_triton_kernel[(hidden_states.shape[0],)](
|
||||
hidden_states,
|
||||
gateup_input,
|
||||
self.src2dst,
|
||||
topk_idx,
|
||||
None,
|
||||
self.router_topk,
|
||||
hidden_states.shape[1],
|
||||
BLOCK_SIZE=512,
|
||||
)
|
||||
return reorder_topk_ids, seg_indptr, gateup_input
|
||||
|
||||
def combine(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_idx: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
if hidden_states.shape[0] > 0:
|
||||
num_tokens = self.src2dst.shape[0] // self.router_topk
|
||||
output = torch.empty(
|
||||
(num_tokens, hidden_states.shape[1]),
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
deepep_post_reorder_triton_kernel[(num_tokens,)](
|
||||
hidden_states,
|
||||
output,
|
||||
self.src2dst,
|
||||
topk_idx,
|
||||
topk_weights,
|
||||
self.router_topk,
|
||||
hidden_states.shape[1],
|
||||
BLOCK_SIZE=512,
|
||||
)
|
||||
else:
|
||||
output = torch.zeros(
|
||||
(0, hidden_states.shape[1]),
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
hidden_states, event = self._combine_normal(
|
||||
output,
|
||||
)
|
||||
event.current_stream_wait() if self.async_finish else ()
|
||||
|
||||
return hidden_states
|
||||
|
||||
def _combine_normal(self, x: torch.Tensor):
|
||||
previous_event = Buffer.capture() if self.async_finish else None
|
||||
|
||||
combined_x, _, event = self.buffer_normal.combine(
|
||||
x,
|
||||
self.handle,
|
||||
async_finish=self.async_finish,
|
||||
previous_event=previous_event,
|
||||
allocate_on_comm_stream=previous_event is not None,
|
||||
)
|
||||
return combined_x, event
|
||||
|
||||
|
||||
class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
|
||||
def __init__(self, return_recv_hook: bool, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
"""
|
||||
num_max_dispatch_tokens_per_rank: the actual batch size in the decoding engine should be less than 256
|
||||
https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
|
||||
"""
|
||||
# TODO(ch-wan): allow users to set this value
|
||||
self.num_max_dispatch_tokens_per_rank = 128
|
||||
self.buffer_low_latency = _get_buffer_low_latency(
|
||||
self.group,
|
||||
self.num_max_dispatch_tokens_per_rank,
|
||||
self.hidden_size,
|
||||
self.num_experts,
|
||||
)
|
||||
self.return_recv_hook = return_recv_hook
|
||||
|
||||
def dispatch(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_idx: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
num_experts: int,
|
||||
num_max_dispatch_tokens_per_rank: int,
|
||||
):
|
||||
topk_idx = topk_idx.to(torch.int64)
|
||||
expected_m = (
|
||||
hidden_states.shape[0]
|
||||
* self.buffer_low_latency.group_size
|
||||
* topk_idx.shape[1]
|
||||
+ num_experts
|
||||
) // num_experts
|
||||
hidden_states, masked_m, event, hook = self._dispatch_low_latency(
|
||||
hidden_states,
|
||||
topk_idx,
|
||||
num_max_dispatch_tokens_per_rank,
|
||||
num_experts,
|
||||
use_fp8=True,
|
||||
)
|
||||
hook() if self.return_recv_hook else event.current_stream_wait()
|
||||
|
||||
# TODO
|
||||
# reorder_topk_ids = torch.empty(
|
||||
# (0,), device=hidden_states.device, dtype=torch.int64
|
||||
# )
|
||||
# seg_indptr = torch.zeros(
|
||||
# (num_experts + 1,), device=hidden_states.device, dtype=torch.int64
|
||||
# )
|
||||
reorder_topk_ids = seg_indptr = None
|
||||
|
||||
return (
|
||||
hidden_states,
|
||||
topk_idx,
|
||||
topk_weights,
|
||||
reorder_topk_ids,
|
||||
seg_indptr,
|
||||
masked_m,
|
||||
expected_m,
|
||||
)
|
||||
|
||||
def _dispatch_low_latency(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_idx: torch.Tensor,
|
||||
@@ -351,62 +448,17 @@ class DeepEPDispatcher:
|
||||
hidden_states: torch.Tensor,
|
||||
topk_idx: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
forward_mode: ForwardMode,
|
||||
) -> torch.Tensor:
|
||||
resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
|
||||
if resolved_deepep_mode == DeepEPMode.normal:
|
||||
if hidden_states.shape[0] > 0:
|
||||
num_tokens = self.src2dst.shape[0] // self.router_topk
|
||||
output = torch.empty(
|
||||
(num_tokens, hidden_states.shape[1]),
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
deepep_post_reorder_triton_kernel[(num_tokens,)](
|
||||
hidden_states,
|
||||
output,
|
||||
self.src2dst,
|
||||
topk_idx,
|
||||
topk_weights,
|
||||
self.router_topk,
|
||||
hidden_states.shape[1],
|
||||
BLOCK_SIZE=512,
|
||||
)
|
||||
else:
|
||||
output = torch.zeros(
|
||||
(0, hidden_states.shape[1]),
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
hidden_states, event = self.combine_normal(
|
||||
output,
|
||||
)
|
||||
event.current_stream_wait() if self.async_finish else ()
|
||||
elif resolved_deepep_mode == DeepEPMode.low_latency:
|
||||
hidden_states, event, hook = self.combine_low_latency(
|
||||
hidden_states,
|
||||
topk_idx,
|
||||
topk_weights,
|
||||
)
|
||||
hook() if self.return_recv_hook else event.current_stream_wait()
|
||||
else:
|
||||
raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")
|
||||
hidden_states, event, hook = self._combine_low_latency(
|
||||
hidden_states,
|
||||
topk_idx,
|
||||
topk_weights,
|
||||
)
|
||||
hook() if self.return_recv_hook else event.current_stream_wait()
|
||||
|
||||
return hidden_states
|
||||
|
||||
def combine_normal(self, x: torch.Tensor):
|
||||
previous_event = Buffer.capture() if self.async_finish else None
|
||||
|
||||
combined_x, _, event = self.buffer_normal.combine(
|
||||
x,
|
||||
self.handle,
|
||||
async_finish=self.async_finish,
|
||||
previous_event=previous_event,
|
||||
allocate_on_comm_stream=previous_event is not None,
|
||||
)
|
||||
return combined_x, event
|
||||
|
||||
def combine_low_latency(
|
||||
def _combine_low_latency(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_idx: torch.Tensor,
|
||||
@@ -423,3 +475,80 @@ class DeepEPDispatcher:
|
||||
)
|
||||
)
|
||||
return combined_hidden_states, event, hook
|
||||
|
||||
|
||||
class DeepEPDispatcher:
|
||||
def __init__(
|
||||
self,
|
||||
group: torch.distributed.ProcessGroup,
|
||||
router_topk: int,
|
||||
permute_fusion: bool = False,
|
||||
num_experts: int = None,
|
||||
num_local_experts: int = None,
|
||||
hidden_size: int = None,
|
||||
params_dtype: torch.dtype = None,
|
||||
deepep_mode: DeepEPMode = DeepEPMode.auto,
|
||||
async_finish: bool = False,
|
||||
return_recv_hook: bool = False,
|
||||
):
|
||||
self.deepep_mode = deepep_mode
|
||||
|
||||
common_kwargs = dict(
|
||||
group=group,
|
||||
router_topk=router_topk,
|
||||
permute_fusion=permute_fusion,
|
||||
num_experts=num_experts,
|
||||
num_local_experts=num_local_experts,
|
||||
hidden_size=hidden_size,
|
||||
params_dtype=params_dtype,
|
||||
)
|
||||
|
||||
if self.deepep_mode.enable_normal():
|
||||
self._normal_dispatcher = _DeepEPDispatcherImplNormal(
|
||||
async_finish=async_finish,
|
||||
**common_kwargs,
|
||||
)
|
||||
if self.deepep_mode.enable_low_latency():
|
||||
self._low_latency_dispatcher = _DeepEPDispatcherImplLowLatency(
|
||||
return_recv_hook=return_recv_hook,
|
||||
**common_kwargs,
|
||||
)
|
||||
|
||||
def dispatch(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_idx: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
num_experts: int,
|
||||
num_max_dispatch_tokens_per_rank: int = 128,
|
||||
forward_mode: ForwardMode = None,
|
||||
) -> Tuple:
|
||||
return self._get_dispatcher(forward_mode).dispatch(
|
||||
hidden_states=hidden_states,
|
||||
topk_idx=topk_idx,
|
||||
topk_weights=topk_weights,
|
||||
num_experts=num_experts,
|
||||
num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank,
|
||||
)
|
||||
|
||||
def combine(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_idx: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
forward_mode: ForwardMode,
|
||||
) -> torch.Tensor:
|
||||
return self._get_dispatcher(forward_mode).combine(
|
||||
hidden_states=hidden_states,
|
||||
topk_idx=topk_idx,
|
||||
topk_weights=topk_weights,
|
||||
)
|
||||
|
||||
def _get_dispatcher(self, forward_mode: ForwardMode) -> _DeepEPDispatcherImplBase:
|
||||
resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
|
||||
if resolved_deepep_mode == DeepEPMode.normal:
|
||||
return self._normal_dispatcher
|
||||
elif resolved_deepep_mode == DeepEPMode.low_latency:
|
||||
return self._low_latency_dispatcher
|
||||
else:
|
||||
raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")
|
||||
|
||||
@@ -188,35 +188,24 @@ class DeepseekV2MoE(nn.Module):
|
||||
if global_server_args_dict["enable_deepep_moe"]
|
||||
else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
|
||||
)
|
||||
if not global_server_args_dict["enable_deepep_moe"]:
|
||||
self.experts = MoEImpl(
|
||||
num_experts=config.n_routed_experts,
|
||||
top_k=config.num_experts_per_tok,
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.moe_intermediate_size,
|
||||
renormalize=config.norm_topk_prob,
|
||||
quant_config=quant_config,
|
||||
use_grouped_topk=True,
|
||||
num_expert_group=config.n_group,
|
||||
topk_group=config.topk_group,
|
||||
correction_bias=self.gate.e_score_correction_bias,
|
||||
prefix=add_prefix("experts", prefix),
|
||||
)
|
||||
else:
|
||||
self.experts = MoEImpl(
|
||||
num_experts=config.n_routed_experts,
|
||||
top_k=config.num_experts_per_tok,
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.moe_intermediate_size,
|
||||
renormalize=config.norm_topk_prob,
|
||||
quant_config=quant_config,
|
||||
use_grouped_topk=True,
|
||||
num_expert_group=config.n_group,
|
||||
topk_group=config.topk_group,
|
||||
correction_bias=self.gate.e_score_correction_bias,
|
||||
prefix=add_prefix("experts", prefix),
|
||||
deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]],
|
||||
)
|
||||
self.experts = MoEImpl(
|
||||
num_experts=config.n_routed_experts,
|
||||
top_k=config.num_experts_per_tok,
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.moe_intermediate_size,
|
||||
renormalize=config.norm_topk_prob,
|
||||
quant_config=quant_config,
|
||||
use_grouped_topk=True,
|
||||
num_expert_group=config.n_group,
|
||||
topk_group=config.topk_group,
|
||||
correction_bias=self.gate.e_score_correction_bias,
|
||||
prefix=add_prefix("experts", prefix),
|
||||
**(
|
||||
dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]])
|
||||
if global_server_args_dict["enable_deepep_moe"]
|
||||
else {}
|
||||
),
|
||||
)
|
||||
|
||||
if config.n_shared_experts is not None:
|
||||
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
|
||||
|
||||
Reference in New Issue
Block a user