fix moe gate dtype, fix tbo, fix fake dispatch (#7825)
This commit is contained in:
@@ -66,7 +66,7 @@ def transform_select_experts_inputs(
|
|||||||
info: Optional[ExpertLocationDispatchInfo],
|
info: Optional[ExpertLocationDispatchInfo],
|
||||||
):
|
):
|
||||||
if (info is not None) and (info.ep_dispatch_algorithm == "fake"):
|
if (info is not None) and (info.ep_dispatch_algorithm == "fake"):
|
||||||
router_logits = torch.randn_like(router_logits)
|
router_logits.uniform_(5, 10)
|
||||||
if correction_bias is not None:
|
if correction_bias is not None:
|
||||||
correction_bias = torch.zeros_like(correction_bias)
|
correction_bias = torch.zeros_like(correction_bias)
|
||||||
return router_logits, correction_bias
|
return router_logits, correction_bias
|
||||||
|
|||||||
@@ -499,7 +499,7 @@ def biased_grouped_topk_gpu(
|
|||||||
and is_power_of_two(correction_bias.shape[0])
|
and is_power_of_two(correction_bias.shape[0])
|
||||||
):
|
):
|
||||||
topk_weights, topk_ids = moe_fused_gate(
|
topk_weights, topk_ids = moe_fused_gate(
|
||||||
gating_output,
|
gating_output.to(dtype=torch.float32),
|
||||||
correction_bias,
|
correction_bias,
|
||||||
num_expert_group,
|
num_expert_group,
|
||||||
topk_group,
|
topk_group,
|
||||||
|
|||||||
@@ -229,7 +229,7 @@ class MoEGate(nn.Module):
|
|||||||
)
|
)
|
||||||
if config.topk_method == "noaux_tc":
|
if config.topk_method == "noaux_tc":
|
||||||
self.e_score_correction_bias = nn.Parameter(
|
self.e_score_correction_bias = nn.Parameter(
|
||||||
torch.empty((config.n_routed_experts))
|
torch.empty((config.n_routed_experts), dtype=torch.float32)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.e_score_correction_bias = None
|
self.e_score_correction_bias = None
|
||||||
|
|||||||
Reference in New Issue
Block a user