fix fused_add_rms_norm bug
This commit is contained in:
@@ -175,7 +175,7 @@ class RMSNorm(CustomOp):
|
||||
self.weight.data,
|
||||
self.variance_epsilon,
|
||||
)
|
||||
return output, residual_out
|
||||
return x, residual
|
||||
except TypeError:
|
||||
fused_add_rms_norm(
|
||||
output,
|
||||
@@ -185,7 +185,8 @@ class RMSNorm(CustomOp):
|
||||
self.weight.data,
|
||||
self.variance_epsilon,
|
||||
)
|
||||
return x, residual
|
||||
return output, residual_out
|
||||
|
||||
|
||||
out = torch.empty_like(x)
|
||||
rms_norm(out, x, self.weight.data, self.variance_epsilon)
|
||||
|
||||
Reference in New Issue
Block a user