diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py index 88524ffae..3c18cea70 100644 --- a/python/sglang/srt/layers/layernorm.py +++ b/python/sglang/srt/layers/layernorm.py @@ -20,12 +20,9 @@ import torch import torch.nn as nn from sglang.srt.custom_op import CustomOp -from sglang.srt.utils import is_cuda, is_hip - -logger = logging.getLogger(__name__) +from sglang.srt.utils import is_cuda _is_cuda = is_cuda() -_is_hip = is_hip() if _is_cuda: from sgl_kernel import ( @@ -35,20 +32,8 @@ if _is_cuda: rmsnorm, ) -if _is_hip: - from aiter.ops.rmsnorm import rms_norm, rmsnorm2d_fwd_with_add - - rmsnorm = rms_norm - - def fused_add_rmsnorm( - x: torch.Tensor, - residual: torch.Tensor, - w: torch.Tensor, - eps: float, - ) -> Tuple[torch.Tensor, torch.Tensor]: - rmsnorm2d_fwd_with_add(x, x, residual, residual, w, eps) - return x, residual +logger = logging.getLogger(__name__) class RMSNorm(CustomOp): @@ -154,7 +139,7 @@ class Gemma3RMSNorm(nn.Module): return f"{tuple(self.weight.shape)}, eps={self.eps}" -if not (_is_cuda or _is_hip): +if not _is_cuda: logger.info( "sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries." )