From 93c6fb12c7736e318f7093e2b653b2071bfb8a10 Mon Sep 17 00:00:00 2001
From: michael-amd
Date: Fri, 25 Apr 2025 13:48:55 -0700
Subject: [PATCH] Fix: deepseek forward absorb (#5723)

Co-authored-by: ispobock
---
 python/sglang/srt/layers/layernorm.py | 45 ++++++++++++++++++++++++---
 1 file changed, 41 insertions(+), 4 deletions(-)

diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py
index 3c18cea70..87322b1b0 100644
--- a/python/sglang/srt/layers/layernorm.py
+++ b/python/sglang/srt/layers/layernorm.py
@@ -20,9 +20,10 @@ import torch
 import torch.nn as nn
 
 from sglang.srt.custom_op import CustomOp
-from sglang.srt.utils import is_cuda
+from sglang.srt.utils import is_cuda, is_hip
 
 _is_cuda = is_cuda()
+_is_hip = is_hip()
 
 if _is_cuda:
     from sgl_kernel import (
@@ -32,6 +33,8 @@ if _is_cuda:
         rmsnorm,
     )
 
+if _is_hip:
+    from vllm._custom_ops import fused_add_rms_norm, rms_norm
 
 logger = logging.getLogger(__name__)
 
@@ -46,23 +49,49 @@ class RMSNorm(CustomOp):
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps
 
+    def forward(self, *args, **kwargs):
+        if torch.compiler.is_compiling():
+            return self.forward_native(*args, **kwargs)
+        if _is_cuda:
+            return self.forward_cuda(*args, **kwargs)
+        elif _is_hip:
+            return self.forward_hip(*args, **kwargs)
+        else:
+            return self.forward_native(*args, **kwargs)
+
     def forward_cuda(
         self,
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-
         if residual is not None:
             fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon)
             return x, residual
         out = rmsnorm(x, self.weight.data, self.variance_epsilon)
         return out
 
+    def forward_hip(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        if not x.is_contiguous():
+            # NOTE: Remove this if the aiter kernel supports discontinuous input
+            x = x.contiguous()
+        if residual is not None:
+            fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
+            return x, residual
+        out = torch.empty_like(x)
+        rms_norm(out, x, self.weight.data, self.variance_epsilon)
+        return out
+
     def forward_native(
         self,
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        if not x.is_contiguous():
+            x = x.contiguous()
         orig_dtype = x.dtype
         x = x.to(torch.float32)
         if residual is not None:
@@ -88,6 +117,14 @@ class GemmaRMSNorm(CustomOp):
         self.weight = nn.Parameter(torch.zeros(hidden_size))
         self.variance_epsilon = eps
 
+    def forward(self, *args, **kwargs):
+        if torch.compiler.is_compiling():
+            return self.forward_native(*args, **kwargs)
+        if _is_cuda:
+            return self.forward_cuda(*args, **kwargs)
+        else:
+            return self.forward_native(*args, **kwargs)
+
     def forward_native(
         self,
         x: torch.Tensor,
@@ -139,8 +176,8 @@ class Gemma3RMSNorm(nn.Module):
         return f"{tuple(self.weight.shape)}, eps={self.eps}"
 
 
-if not _is_cuda:
+if not (_is_cuda or _is_hip):
     logger.info(
-        "sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."
+        "sgl-kernel layernorm implementation is not available on the current platform. Falling back to other kernel libraries."
     )
     from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
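
A minimal usage sketch of the dispatch introduced above. This snippet is illustrative and not part of the patch: the RMSNorm constructor arguments are inferred from the fields set in __init__ shown in the diff, and it calls forward_native directly so it runs on any device (the new forward() routes to forward_cuda or forward_hip when those backends are present, and to forward_native under torch.compile or when neither kernel library is available).

    import torch

    from sglang.srt.layers.layernorm import RMSNorm

    # Hypothetical example: hidden_size and eps are inferred from the diff, not a documented API contract.
    norm = RMSNorm(128, eps=1e-6)
    x = torch.randn(4, 128)
    residual = torch.randn(4, 128)

    # Without a residual, the layer returns the normalized tensor.
    out = norm.forward_native(x)
    # With a residual, it returns the normalized sum and the updated residual,
    # matching the (x, residual) tuple contract of the fused CUDA/HIP paths.
    out_fused, new_residual = norm.forward_native(x, residual)
    print(out.shape, out_fused.shape, new_residual.shape)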