Refactor DeepSeek MoE layer to unify the two forward branches (#6325)
This commit is contained in:
@@ -194,6 +194,14 @@ class MoEGate(nn.Module):
|
|||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
|
||||||
|
def is_non_idle_and_non_empty(forward_mode, hidden_states):
|
||||||
|
return (
|
||||||
|
(forward_mode is not None)
|
||||||
|
and not forward_mode.is_idle()
|
||||||
|
and hidden_states.shape[0] > 0
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class DeepseekV2MoE(nn.Module):
|
class DeepseekV2MoE(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -259,11 +267,12 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.top_k = config.num_experts_per_tok
|
||||||
|
|
||||||
if global_server_args_dict["enable_deepep_moe"]:
|
if global_server_args_dict["enable_deepep_moe"]:
|
||||||
# TODO: we will support tp < ep in the future
|
# TODO: we will support tp < ep in the future
|
||||||
self.ep_size = get_tensor_model_parallel_world_size()
|
self.ep_size = get_tensor_model_parallel_world_size()
|
||||||
self.num_experts = config.n_routed_experts
|
self.num_experts = config.n_routed_experts
|
||||||
self.top_k = config.num_experts_per_tok
|
|
||||||
self.renormalize = config.norm_topk_prob
|
self.renormalize = config.norm_topk_prob
|
||||||
self.topk_group = config.topk_group
|
self.topk_group = config.topk_group
|
||||||
self.num_expert_group = config.n_group
|
self.num_expert_group = config.n_group
|
||||||
@@ -286,41 +295,30 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
return_recv_hook=True,
|
return_recv_hook=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _enable_deepep_moe(self):
|
||||||
|
return global_server_args_dict["enable_deepep_moe"]
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
|
self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
if not global_server_args_dict["enable_deepep_moe"]:
|
if (not self._enable_deepep_moe) or is_non_idle_and_non_empty(
|
||||||
return self.forward_normal(hidden_states)
|
forward_mode, hidden_states
|
||||||
else:
|
|
||||||
return self.forward_deepep(hidden_states, forward_batch)
|
|
||||||
|
|
||||||
def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
|
||||||
shared_output = self._forward_shared_experts(hidden_states)
|
|
||||||
# router_logits: (num_tokens, n_experts)
|
|
||||||
router_logits = self.gate(hidden_states)
|
|
||||||
final_hidden_states = self.experts(
|
|
||||||
hidden_states=hidden_states, router_logits=router_logits
|
|
||||||
)
|
|
||||||
final_hidden_states *= self.routed_scaling_factor
|
|
||||||
if shared_output is not None:
|
|
||||||
final_hidden_states = final_hidden_states + shared_output
|
|
||||||
if self.tp_size > 1:
|
|
||||||
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
|
|
||||||
return final_hidden_states
|
|
||||||
|
|
||||||
def forward_deepep(
|
|
||||||
self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
|
|
||||||
) -> torch.Tensor:
|
|
||||||
forward_mode = forward_batch.forward_mode
|
|
||||||
shared_output = None
|
|
||||||
if (
|
|
||||||
forward_mode is not None
|
|
||||||
and not forward_mode.is_idle()
|
|
||||||
and hidden_states.shape[0] > 0
|
|
||||||
):
|
):
|
||||||
# router_logits: (num_tokens, n_experts)
|
# router_logits: (num_tokens, n_experts)
|
||||||
router_logits = self.gate(hidden_states)
|
router_logits = self.gate(hidden_states)
|
||||||
shared_output = self._forward_shared_experts(hidden_states)
|
else:
|
||||||
|
router_logits = None
|
||||||
|
|
||||||
|
if (self.n_share_experts_fusion == 0) and (
|
||||||
|
(not self._enable_deepep_moe)
|
||||||
|
or is_non_idle_and_non_empty(forward_mode, hidden_states)
|
||||||
|
):
|
||||||
|
shared_output = self.shared_experts(hidden_states)
|
||||||
|
else:
|
||||||
|
shared_output = None
|
||||||
|
|
||||||
|
if self._enable_deepep_moe and (router_logits is not None):
|
||||||
topk_weights, topk_idx = select_experts(
|
topk_weights, topk_idx = select_experts(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
router_logits=router_logits,
|
router_logits=router_logits,
|
||||||
@@ -340,7 +338,8 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
topk_weights = torch.empty(
|
topk_weights = torch.empty(
|
||||||
(0, self.top_k), dtype=torch.float32, device=hidden_states.device
|
(0, self.top_k), dtype=torch.float32, device=hidden_states.device
|
||||||
)
|
)
|
||||||
if self.ep_size > 1:
|
|
||||||
|
if self._enable_deepep_moe and (self.ep_size > 1):
|
||||||
# TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
|
# TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
|
||||||
(
|
(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
@@ -357,36 +356,41 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
topk_weights,
|
topk_weights,
|
||||||
forward_mode=forward_mode,
|
forward_mode=forward_mode,
|
||||||
)
|
)
|
||||||
final_hidden_states = self.experts(
|
|
||||||
hidden_states=hidden_states,
|
if self._enable_deepep_moe:
|
||||||
topk_idx=topk_idx,
|
final_hidden_states = self.experts(
|
||||||
topk_weights=topk_weights,
|
hidden_states=hidden_states,
|
||||||
reorder_topk_ids=reorder_topk_ids,
|
topk_idx=topk_idx,
|
||||||
seg_indptr=seg_indptr,
|
topk_weights=topk_weights,
|
||||||
masked_m=masked_m,
|
reorder_topk_ids=reorder_topk_ids,
|
||||||
expected_m=expected_m,
|
seg_indptr=seg_indptr,
|
||||||
num_recv_tokens_per_expert=num_recv_tokens_per_expert,
|
masked_m=masked_m,
|
||||||
forward_mode=forward_mode,
|
expected_m=expected_m,
|
||||||
)
|
num_recv_tokens_per_expert=num_recv_tokens_per_expert,
|
||||||
if self.ep_size > 1:
|
forward_mode=forward_mode,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
final_hidden_states = self.experts(
|
||||||
|
hidden_states=hidden_states, router_logits=router_logits
|
||||||
|
)
|
||||||
|
|
||||||
|
if self._enable_deepep_moe and (self.ep_size > 1):
|
||||||
final_hidden_states = self.deepep_dispatcher.combine(
|
final_hidden_states = self.deepep_dispatcher.combine(
|
||||||
final_hidden_states,
|
final_hidden_states,
|
||||||
topk_idx,
|
topk_idx,
|
||||||
topk_weights,
|
topk_weights,
|
||||||
forward_mode,
|
forward_mode,
|
||||||
)
|
)
|
||||||
|
|
||||||
final_hidden_states *= self.routed_scaling_factor
|
final_hidden_states *= self.routed_scaling_factor
|
||||||
|
|
||||||
if shared_output is not None:
|
if shared_output is not None:
|
||||||
final_hidden_states = final_hidden_states + shared_output
|
final_hidden_states = final_hidden_states + shared_output
|
||||||
|
|
||||||
return final_hidden_states
|
if (not self._enable_deepep_moe) and (self.tp_size > 1):
|
||||||
|
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
|
||||||
|
|
||||||
def _forward_shared_experts(self, hidden_states):
|
return final_hidden_states
|
||||||
if self.n_share_experts_fusion == 0:
|
|
||||||
return self.shared_experts(hidden_states)
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
|
def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
|
||||||
|
|||||||
Reference in New Issue
Block a user