diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py index 34d6eb55a..528b08323 100644 --- a/python/sglang/srt/layers/layernorm.py +++ b/python/sglang/srt/layers/layernorm.py @@ -175,7 +175,7 @@ class RMSNorm(CustomOp): self.weight.data, self.variance_epsilon, ) - return output, residual_out + return x, residual except TypeError: fused_add_rms_norm( output, @@ -185,7 +185,8 @@ class RMSNorm(CustomOp): self.weight.data, self.variance_epsilon, ) - return x, residual + return output, residual_out + out = torch.empty_like(x) rms_norm(out, x, self.weight.data, self.variance_epsilon)