Remve router gemm output dtype conversion (#8204)
This commit is contained in:
@@ -254,9 +254,8 @@ class MoEGate(nn.Module):
|
|||||||
and self.weight.shape[0] == 256
|
and self.weight.shape[0] == 256
|
||||||
and _device_sm >= 90
|
and _device_sm >= 90
|
||||||
):
|
):
|
||||||
logits = dsv3_router_gemm(hidden_states, self.weight).to(
|
# router gemm output float32
|
||||||
hidden_states.dtype
|
logits = dsv3_router_gemm(hidden_states, self.weight)
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
logits = F.linear(hidden_states, self.weight, None)
|
logits = F.linear(hidden_states, self.weight, None)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user