Remove router gemm output dtype conversion (#8204)

This commit is contained in:
Ke Bao
2025-07-21 15:37:00 +08:00
committed by GitHub
parent 9b5de6cb06
commit 6936be3221

View File

@@ -254,9 +254,8 @@ class MoEGate(nn.Module):
and self.weight.shape[0] == 256
and _device_sm >= 90
):
logits = dsv3_router_gemm(hidden_states, self.weight).to(
hidden_states.dtype
)
# router gemm output float32
logits = dsv3_router_gemm(hidden_states, self.weight)
else:
logits = F.linear(hidden_states, self.weight, None)