diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py index 4e3d39e77..59489cdb8 100644 --- a/python/sglang/srt/layers/layernorm.py +++ b/python/sglang/srt/layers/layernorm.py @@ -26,6 +26,7 @@ from sglang.srt.utils import ( get_bool_env_var, is_cpu, is_cuda, + is_flashinfer_available, is_hip, is_npu, is_xpu, @@ -33,6 +34,7 @@ from sglang.srt.utils import ( ) _is_cuda = is_cuda() +_is_flashinfer_available = is_flashinfer_available() _is_hip = is_hip() _is_npu = is_npu() _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip @@ -41,7 +43,10 @@ _is_cpu = is_cpu() _is_xpu = is_xpu() if _is_cuda: - from flashinfer.norm import fused_add_rmsnorm as flashinfer_fused_add_rmsnorm + if _is_flashinfer_available: + from flashinfer.norm import fused_add_rmsnorm + else: + from sgl_kernel import fused_add_rmsnorm from sgl_kernel import gemma_fused_add_rmsnorm, gemma_rmsnorm, rmsnorm if _use_aiter: @@ -84,9 +89,7 @@ class RMSNorm(CustomOp): if self.variance_size_override is not None: return self.forward_native(x, residual) if residual is not None: - flashinfer_fused_add_rmsnorm( - x, residual, self.weight.data, self.variance_epsilon - ) + fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon) return x, residual out = rmsnorm(x, self.weight.data, self.variance_epsilon) return out