From 180ff5eecc2da2231eb3ef29f70aa8d62fd8e168 Mon Sep 17 00:00:00 2001
From: JieXin Liang
Date: Wed, 4 Jun 2025 12:44:20 +0800
Subject: [PATCH] [fix] recover auto-dispatch for rmsnorm and rope (#6745)

---
 python/sglang/srt/custom_op.py               | 20 +++++++++++++++++++-
 python/sglang/srt/layers/layernorm.py        | 20 +++-----------------
 python/sglang/srt/layers/rotary_embedding.py | 18 ++++++------------
 3 files changed, 28 insertions(+), 30 deletions(-)

diff --git a/python/sglang/srt/custom_op.py b/python/sglang/srt/custom_op.py
index fe6176f4e..ba34dc8e4 100644
--- a/python/sglang/srt/custom_op.py
+++ b/python/sglang/srt/custom_op.py
@@ -11,7 +11,20 @@ class CustomOp(nn.Module):
         super().__init__()
         self._forward_method = self.dispatch_forward()

+        # States for torch.compile
+        self._original_forward_method = None
+        self.is_torch_compile = False
+
     def enter_torch_compile(self, num_tokens: int):
+        # Skip if the op has already entered compile mode.
+        # NOTE(alcanderian): Some ops (for example RotaryEmbedding) are reused
+        # across layers, so `enter_torch_compile` may be called multiple times.
+        # `self._original_forward_method` must not be overwritten when this is
+        # not the first call to `enter_torch_compile`.
+        if self.is_torch_compile:
+            return
+
+        self._original_forward_method = self._forward_method
         # NOTE: Temporarily workaround MoE
         if "FusedMoE" in self.__class__.__name__:
             if num_tokens == 1:
@@ -27,7 +40,12 @@ class CustomOp(nn.Module):
         self.is_torch_compile = True

     def leave_torch_compile(self):
-        self._forward_method = self.forward_cuda
+        # Skip if the op has already exited compile mode.
+        if not self.is_torch_compile:
+            return
+
+        self._forward_method = self._original_forward_method
+        self._original_forward_method = None
         self.is_torch_compile = False

     # Please do not override this method, because `self._forward_method` can change when in torch compile mode
diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py
index 98ae3d83d..1f398549b 100644
--- a/python/sglang/srt/layers/layernorm.py
+++ b/python/sglang/srt/layers/layernorm.py
@@ -49,16 +49,6 @@ class RMSNorm(CustomOp):
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps

-    def forward(self, *args, **kwargs):
-        if torch.compiler.is_compiling():
-            return self.forward_native(*args, **kwargs)
-        if _is_cuda:
-            return self.forward_cuda(*args, **kwargs)
-        elif _is_hip:
-            return self.forward_hip(*args, **kwargs)
-        else:
-            return self.forward_native(*args, **kwargs)
-
     def forward_cuda(
         self,
         x: torch.Tensor,
@@ -117,13 +107,9 @@ class GemmaRMSNorm(CustomOp):
         self.weight = nn.Parameter(torch.zeros(hidden_size))
         self.variance_epsilon = eps

-    def forward(self, *args, **kwargs):
-        if torch.compiler.is_compiling():
-            return self.forward_native(*args, **kwargs)
-        if _is_cuda:
-            return self.forward_cuda(*args, **kwargs)
-        else:
-            return self.forward_native(*args, **kwargs)
+        # Re-dispatch
+        if _is_hip:
+            self._forward_method = self.forward_native

     def forward_native(
         self,
diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py
index c5c285ca0..8ae191b51 100644
--- a/python/sglang/srt/layers/rotary_embedding.py
+++ b/python/sglang/srt/layers/rotary_embedding.py
@@ -8,9 +8,10 @@ import torch
 import torch.nn as nn

 from sglang.srt.custom_op import CustomOp
-from sglang.srt.utils import is_cuda
+from sglang.srt.utils import is_cuda, is_hip

 _is_cuda = is_cuda()
+_is_hip = is_hip()

 if _is_cuda:
     from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
@@ -609,6 +610,10 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
             head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
         )

+        # Re-dispatch
+        if _is_hip:
+            self._forward_method = self.forward_native
+
     def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
         pos_freqs = self.base ** (
             torch.arange(0, self.rotary_dim, 2, dtype=torch.float, device=self.device)
@@ -650,17 +655,6 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         cache = torch.cat((cos, sin), dim=-1)
         return cache

-    def forward_hip(self, *args, **kwargs):
-        return self.forward_native(*args, **kwargs)
-
-    def forward(self, *args, **kwargs):
-        if torch.compiler.is_compiling():
-            return self.forward_native(*args, **kwargs)
-        if _is_cuda:
-            return self.forward_cuda(*args, **kwargs)
-        else:
-            return self.forward_native(*args, **kwargs)
-
     def forward_native(
         self,
         positions: torch.Tensor,
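For context, the dispatch pattern this patch restores can be summarized outside the diff. The listing below is a minimal, self-contained sketch, not the actual sglang classes: `SketchCustomOp` and its trivial `dispatch_forward` are hypothetical stand-ins, and `enter_torch_compile` drops the real `num_tokens` argument. The point it illustrates is the save/restore behavior: entering compile mode records whatever backend-specific method was chosen at construction time, and leaving restores it instead of unconditionally resetting to `forward_cuda` as the removed code did.

import torch.nn as nn


class SketchCustomOp(nn.Module):
    """Minimal stand-in for sglang's CustomOp dispatch (illustration only)."""

    def __init__(self):
        super().__init__()
        # Backend-specific implementation is chosen once, at construction time.
        self._forward_method = self.dispatch_forward()
        # States for torch.compile.
        self._original_forward_method = None
        self.is_torch_compile = False

    def dispatch_forward(self):
        # Hypothetical stand-in for the real CUDA/HIP/native dispatch.
        return self.forward_native

    def forward_native(self, x):
        return x

    def forward_cuda(self, x):
        return x

    def enter_torch_compile(self):
        # Ops shared across layers (e.g. RotaryEmbedding) may be asked to enter
        # more than once; only the first call records the original dispatch target.
        if self.is_torch_compile:
            return
        self._original_forward_method = self._forward_method
        self._forward_method = self.forward_native  # compile-friendly path
        self.is_torch_compile = True

    def leave_torch_compile(self):
        if not self.is_torch_compile:
            return
        # Restore whatever dispatch_forward() chose, rather than hard-coding
        # forward_cuda as the pre-fix code did.
        self._forward_method = self._original_forward_method
        self._original_forward_method = None
        self.is_torch_compile = False

    def forward(self, x):
        return self._forward_method(x)

With this guard in place, an op reused by several layers keeps the dispatch target recorded on its first enter call, so a HIP or pure-native deployment is not silently switched onto `forward_cuda` once compilation ends.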