Remove router gemm output dtype conversion (#8204)
This commit is contained in:
@@ -254,9 +254,8 @@ class MoEGate(nn.Module):
|
||||
and self.weight.shape[0] == 256
|
||||
and _device_sm >= 90
|
||||
):
|
||||
logits = dsv3_router_gemm(hidden_states, self.weight).to(
|
||||
hidden_states.dtype
|
||||
)
|
||||
# router gemm output float32
|
||||
logits = dsv3_router_gemm(hidden_states, self.weight)
|
||||
else:
|
||||
logits = F.linear(hidden_states, self.weight, None)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user