Refactor DeepSeek MoE layer to unify the two forward branches (#6325)
This commit is contained in:
@@ -194,6 +194,14 @@ class MoEGate(nn.Module):
|
||||
return logits
|
||||
|
||||
|
||||
def is_non_idle_and_non_empty(forward_mode, hidden_states):
|
||||
return (
|
||||
(forward_mode is not None)
|
||||
and not forward_mode.is_idle()
|
||||
and hidden_states.shape[0] > 0
|
||||
)
|
||||
|
||||
|
||||
class DeepseekV2MoE(nn.Module):
|
||||
|
||||
def __init__(
|
||||
@@ -259,11 +267,12 @@ class DeepseekV2MoE(nn.Module):
|
||||
),
|
||||
)
|
||||
|
||||
self.top_k = config.num_experts_per_tok
|
||||
|
||||
if global_server_args_dict["enable_deepep_moe"]:
|
||||
# TODO: we will support tp < ep in the future
|
||||
self.ep_size = get_tensor_model_parallel_world_size()
|
||||
self.num_experts = config.n_routed_experts
|
||||
self.top_k = config.num_experts_per_tok
|
||||
self.renormalize = config.norm_topk_prob
|
||||
self.topk_group = config.topk_group
|
||||
self.num_expert_group = config.n_group
|
||||
@@ -286,41 +295,30 @@ class DeepseekV2MoE(nn.Module):
|
||||
return_recv_hook=True,
|
||||
)
|
||||
|
||||
@property
|
||||
def _enable_deepep_moe(self):
|
||||
return global_server_args_dict["enable_deepep_moe"]
|
||||
|
||||
def forward(
|
||||
self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
|
||||
) -> torch.Tensor:
|
||||
if not global_server_args_dict["enable_deepep_moe"]:
|
||||
return self.forward_normal(hidden_states)
|
||||
else:
|
||||
return self.forward_deepep(hidden_states, forward_batch)
|
||||
|
||||
def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
shared_output = self._forward_shared_experts(hidden_states)
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits = self.gate(hidden_states)
|
||||
final_hidden_states = self.experts(
|
||||
hidden_states=hidden_states, router_logits=router_logits
|
||||
)
|
||||
final_hidden_states *= self.routed_scaling_factor
|
||||
if shared_output is not None:
|
||||
final_hidden_states = final_hidden_states + shared_output
|
||||
if self.tp_size > 1:
|
||||
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
|
||||
return final_hidden_states
|
||||
|
||||
def forward_deepep(
|
||||
self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
|
||||
) -> torch.Tensor:
|
||||
forward_mode = forward_batch.forward_mode
|
||||
shared_output = None
|
||||
if (
|
||||
forward_mode is not None
|
||||
and not forward_mode.is_idle()
|
||||
and hidden_states.shape[0] > 0
|
||||
if (not self._enable_deepep_moe) or is_non_idle_and_non_empty(
|
||||
forward_mode, hidden_states
|
||||
):
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits = self.gate(hidden_states)
|
||||
shared_output = self._forward_shared_experts(hidden_states)
|
||||
else:
|
||||
router_logits = None
|
||||
|
||||
if (self.n_share_experts_fusion == 0) and (
|
||||
(not self._enable_deepep_moe)
|
||||
or is_non_idle_and_non_empty(forward_mode, hidden_states)
|
||||
):
|
||||
shared_output = self.shared_experts(hidden_states)
|
||||
else:
|
||||
shared_output = None
|
||||
|
||||
if self._enable_deepep_moe and (router_logits is not None):
|
||||
topk_weights, topk_idx = select_experts(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
@@ -340,7 +338,8 @@ class DeepseekV2MoE(nn.Module):
|
||||
topk_weights = torch.empty(
|
||||
(0, self.top_k), dtype=torch.float32, device=hidden_states.device
|
||||
)
|
||||
if self.ep_size > 1:
|
||||
|
||||
if self._enable_deepep_moe and (self.ep_size > 1):
|
||||
# TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
|
||||
(
|
||||
hidden_states,
|
||||
@@ -357,36 +356,41 @@ class DeepseekV2MoE(nn.Module):
|
||||
topk_weights,
|
||||
forward_mode=forward_mode,
|
||||
)
|
||||
final_hidden_states = self.experts(
|
||||
hidden_states=hidden_states,
|
||||
topk_idx=topk_idx,
|
||||
topk_weights=topk_weights,
|
||||
reorder_topk_ids=reorder_topk_ids,
|
||||
seg_indptr=seg_indptr,
|
||||
masked_m=masked_m,
|
||||
expected_m=expected_m,
|
||||
num_recv_tokens_per_expert=num_recv_tokens_per_expert,
|
||||
forward_mode=forward_mode,
|
||||
)
|
||||
if self.ep_size > 1:
|
||||
|
||||
if self._enable_deepep_moe:
|
||||
final_hidden_states = self.experts(
|
||||
hidden_states=hidden_states,
|
||||
topk_idx=topk_idx,
|
||||
topk_weights=topk_weights,
|
||||
reorder_topk_ids=reorder_topk_ids,
|
||||
seg_indptr=seg_indptr,
|
||||
masked_m=masked_m,
|
||||
expected_m=expected_m,
|
||||
num_recv_tokens_per_expert=num_recv_tokens_per_expert,
|
||||
forward_mode=forward_mode,
|
||||
)
|
||||
else:
|
||||
final_hidden_states = self.experts(
|
||||
hidden_states=hidden_states, router_logits=router_logits
|
||||
)
|
||||
|
||||
if self._enable_deepep_moe and (self.ep_size > 1):
|
||||
final_hidden_states = self.deepep_dispatcher.combine(
|
||||
final_hidden_states,
|
||||
topk_idx,
|
||||
topk_weights,
|
||||
forward_mode,
|
||||
)
|
||||
|
||||
final_hidden_states *= self.routed_scaling_factor
|
||||
|
||||
if shared_output is not None:
|
||||
final_hidden_states = final_hidden_states + shared_output
|
||||
|
||||
return final_hidden_states
|
||||
if (not self._enable_deepep_moe) and (self.tp_size > 1):
|
||||
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
|
||||
|
||||
def _forward_shared_experts(self, hidden_states):
|
||||
if self.n_share_experts_fusion == 0:
|
||||
return self.shared_experts(hidden_states)
|
||||
else:
|
||||
return None
|
||||
return final_hidden_states
|
||||
|
||||
|
||||
def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
|
||||
|
||||
Reference in New Issue
Block a user