Support both approximate and exact expert distribution collection (#6964)
This commit is contained in:
@@ -460,22 +460,25 @@ class DeepseekV2MoE(nn.Module):
|
||||
hidden_states = state.hidden_states_mlp_input
|
||||
|
||||
if router_logits is not None:
|
||||
state.topk_weights_local, state.topk_idx_local = select_experts(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
top_k=self.top_k,
|
||||
use_grouped_topk=True,
|
||||
renormalize=self.renormalize,
|
||||
topk_group=self.topk_group,
|
||||
num_expert_group=self.num_expert_group,
|
||||
num_fused_shared_experts=self.num_fused_shared_experts,
|
||||
correction_bias=self.correction_bias,
|
||||
routed_scaling_factor=self.routed_scaling_factor,
|
||||
num_token_non_padded=state.forward_batch.num_token_non_padded,
|
||||
expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
|
||||
layer_id=self.layer_id,
|
||||
),
|
||||
)
|
||||
with get_global_expert_distribution_recorder().with_current_layer(
|
||||
self.layer_id
|
||||
):
|
||||
state.topk_weights_local, state.topk_idx_local = select_experts(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
top_k=self.top_k,
|
||||
use_grouped_topk=True,
|
||||
renormalize=self.renormalize,
|
||||
topk_group=self.topk_group,
|
||||
num_expert_group=self.num_expert_group,
|
||||
num_fused_shared_experts=self.num_fused_shared_experts,
|
||||
correction_bias=self.correction_bias,
|
||||
routed_scaling_factor=self.routed_scaling_factor,
|
||||
num_token_non_padded=state.forward_batch.num_token_non_padded,
|
||||
expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
|
||||
layer_id=self.layer_id,
|
||||
),
|
||||
)
|
||||
else:
|
||||
state.topk_idx_local = torch.full(
|
||||
(0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
|
||||
|
||||
Reference in New Issue
Block a user