fix moe gate dtype, fix tbo, fix fake dispatch (#7825)
This commit is contained in:
@@ -66,7 +66,7 @@ def transform_select_experts_inputs(
     info: Optional[ExpertLocationDispatchInfo],
 ):
     if (info is not None) and (info.ep_dispatch_algorithm == "fake"):
-        router_logits = torch.randn_like(router_logits)
+        router_logits.uniform_(5, 10)
         if correction_bias is not None:
             correction_bias = torch.zeros_like(correction_bias)
     return router_logits, correction_bias
||||
@@ -499,7 +499,7 @@ def biased_grouped_topk_gpu(
         and is_power_of_two(correction_bias.shape[0])
     ):
         topk_weights, topk_ids = moe_fused_gate(
-            gating_output,
+            gating_output.to(dtype=torch.float32),
             correction_bias,
             num_expert_group,
             topk_group,
||||
@@ -229,7 +229,7 @@ class MoEGate(nn.Module):
             )
             if config.topk_method == "noaux_tc":
                 self.e_score_correction_bias = nn.Parameter(
-                    torch.empty((config.n_routed_experts))
+                    torch.empty((config.n_routed_experts), dtype=torch.float32)
                 )
             else:
                 self.e_score_correction_bias = None
Reference in New Issue
Block a user