From bc5fc332f75d9182c1a1d123cf1fb7f940796334 Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Mon, 8 Sep 2025 11:20:39 +0800 Subject: [PATCH] Fix slow fused add RMSNorm (#10141) --- python/sglang/srt/layers/layernorm.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py index 7743b888e..81ec3693a 100644 --- a/python/sglang/srt/layers/layernorm.py +++ b/python/sglang/srt/layers/layernorm.py @@ -39,12 +39,8 @@ _is_cpu_amx_available = cpu_has_amx_support() _is_cpu = is_cpu() if _is_cuda: - from sgl_kernel import ( - fused_add_rmsnorm, - gemma_fused_add_rmsnorm, - gemma_rmsnorm, - rmsnorm, - ) + from flashinfer.norm import fused_add_rmsnorm as flashinfer_fused_add_rmsnorm + from sgl_kernel import gemma_fused_add_rmsnorm, gemma_rmsnorm, rmsnorm if _use_aiter: from aiter import rmsnorm2d_fwd as rms_norm @@ -86,7 +82,9 @@ class RMSNorm(CustomOp): if self.variance_size_override is not None: return self.forward_native(x, residual) if residual is not None: - fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon) + flashinfer_fused_add_rmsnorm( + x, residual, self.weight.data, self.variance_epsilon + ) return x, residual out = rmsnorm(x, self.weight.data, self.variance_epsilon) return out