fix fused_add_rms_norm bug

This commit is contained in:
maxiao
2025-10-27 10:27:57 +08:00
parent b80ae5e9ff
commit f9a026ad2b

View File

@@ -175,7 +175,7 @@ class RMSNorm(CustomOp):
self.weight.data,
self.variance_epsilon,
)
return output, residual_out
return x, residual
except TypeError:
fused_add_rms_norm(
output,
@@ -185,7 +185,8 @@ class RMSNorm(CustomOp):
self.weight.data,
self.variance_epsilon,
)
return x, residual
return output, residual_out
out = torch.empty_like(x)
rms_norm(out, x, self.weight.data, self.variance_epsilon)