diff --git a/python/sglang/srt/layers/elementwise.py b/python/sglang/srt/layers/elementwise.py index e05d88b32..899518034 100644 --- a/python/sglang/srt/layers/elementwise.py +++ b/python/sglang/srt/layers/elementwise.py @@ -187,7 +187,9 @@ fused_dual_residual_rmsnorm_kernel_autotune = rmsnorm_autotune( def fused_dual_residual_rmsnorm(x, residual, weight1, weight2, eps, autotune=False): assert len(x.shape) == 2 - assert x.shape == residual.shape and x.dtype == residual.dtype + assert ( + x.shape == residual.shape and x.dtype == residual.dtype + ), f"{x.shape=} {residual.shape=} {x.dtype=} {residual.dtype=}" output, mid = torch.empty_like(x), torch.empty_like(x) bs, hidden_dim = x.shape if autotune: