fix fused_add_rms_norm bug
This commit is contained in:
@@ -175,7 +175,7 @@ class RMSNorm(CustomOp):
|
|||||||
self.weight.data,
|
self.weight.data,
|
||||||
self.variance_epsilon,
|
self.variance_epsilon,
|
||||||
)
|
)
|
||||||
return output, residual_out
|
return x, residual
|
||||||
except TypeError:
|
except TypeError:
|
||||||
fused_add_rms_norm(
|
fused_add_rms_norm(
|
||||||
output,
|
output,
|
||||||
@@ -185,7 +185,8 @@ class RMSNorm(CustomOp):
|
|||||||
self.weight.data,
|
self.weight.data,
|
||||||
self.variance_epsilon,
|
self.variance_epsilon,
|
||||||
)
|
)
|
||||||
return x, residual
|
return output, residual_out
|
||||||
|
|
||||||
|
|
||||||
out = torch.empty_like(x)
|
out = torch.empty_like(x)
|
||||||
rms_norm(out, x, self.weight.data, self.variance_epsilon)
|
rms_norm(out, x, self.weight.data, self.variance_epsilon)
|
||||||
|
|||||||
Reference in New Issue
Block a user