[fix] recover auto-dispatch for rmsnorm and rope (#6745)
@@ -11,7 +11,20 @@ class CustomOp(nn.Module):
         super().__init__()
         self._forward_method = self.dispatch_forward()
+
+        # States for torch.compile
+        self._original_forward_method = None
+        self.is_torch_compile = False
 
     def enter_torch_compile(self, num_tokens: int):
+        # Skip if the Op has already entered compile mode.
+        # NOTE(alcanderian): Some Ops (for example RotaryEmbedding) are reused
+        # across layers, so `enter_torch_compile` will be called many times.
+        # We should prevent `self._original_forward_method` from being overwritten
+        # when this is not the first `enter_torch_compile` call.
+        if self.is_torch_compile:
+            return
+
+        self._original_forward_method = self._forward_method
         # NOTE: Temporarily workaround MoE
         if "FusedMoE" in self.__class__.__name__:
             if num_tokens == 1:
@@ -27,7 +40,12 @@ class CustomOp(nn.Module):
         self.is_torch_compile = True
 
     def leave_torch_compile(self):
-        self._forward_method = self.forward_cuda
+        # Skip if the Op has already exited compile mode.
+        if not self.is_torch_compile:
+            return
+
+        self._forward_method = self._original_forward_method
+        self._original_forward_method = None
         self.is_torch_compile = False
 
     # Please do not override this method, because `self._forward_method` can change when in torch compile mode
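Taken together, the two hunks above turn `enter_torch_compile` / `leave_torch_compile` into an idempotent save-and-restore of the dispatched forward, instead of unconditionally pinning `forward_cuda` on exit. Below is a minimal, self-contained sketch of that pattern; the class and the `forward_compile_friendly` variant are made up for illustration, only the attribute names mirror the diff.

class _SketchOp:
    """Stand-in for CustomOp: saves and restores its dispatched forward."""

    def __init__(self):
        # Whatever dispatch_forward() would have picked for this platform.
        self._forward_method = self.forward_native
        self._original_forward_method = None
        self.is_torch_compile = False

    def forward_native(self, x):
        return x + 1

    def forward_compile_friendly(self, x):
        # Pretend this variant is the one torch.compile should trace.
        return x + 1

    def enter_torch_compile(self, num_tokens: int):
        if self.is_torch_compile:
            # Idempotent: a shared op (e.g. a reused RotaryEmbedding) may be
            # entered once per layer; only the first call saves the original.
            return
        self._original_forward_method = self._forward_method
        self._forward_method = self.forward_compile_friendly
        self.is_torch_compile = True

    def leave_torch_compile(self):
        if not self.is_torch_compile:
            return
        self._forward_method = self._original_forward_method
        self._original_forward_method = None
        self.is_torch_compile = False


op = _SketchOp()
op.enter_torch_compile(num_tokens=1)
op.enter_torch_compile(num_tokens=1)  # no-op; the saved original is not clobbered
op.leave_torch_compile()
assert op._forward_method == op.forward_native  # back to the pre-compile dispatch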
@@ -49,16 +49,6 @@ class RMSNorm(CustomOp):
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps
 
-    def forward(self, *args, **kwargs):
-        if torch.compiler.is_compiling():
-            return self.forward_native(*args, **kwargs)
-        if _is_cuda:
-            return self.forward_cuda(*args, **kwargs)
-        elif _is_hip:
-            return self.forward_hip(*args, **kwargs)
-        else:
-            return self.forward_native(*args, **kwargs)
-
     def forward_cuda(
         self,
         x: torch.Tensor,
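With the hand-written `forward` override removed, `RMSNorm` again goes through the base-class dispatch that the `# Please do not override this method` comment above protects. The `CustomOp` internals are not part of this diff; what follows is only a sketch of the presumed shape of that dispatch, with the platform flags reduced to plain booleans.

import torch.nn as nn

_is_cuda = False  # assumed flags; sglang derives these from is_cuda() / is_hip()
_is_hip = False


class CustomOpSketch(nn.Module):
    """Sketch of the auto-dispatch this PR recovers; not the real CustomOp."""

    def __init__(self):
        super().__init__()
        self._forward_method = self.dispatch_forward()

    def forward(self, *args, **kwargs):
        # Always route through the (mutable) dispatched method, so that
        # enter_torch_compile / leave_torch_compile can swap it at runtime.
        return self._forward_method(*args, **kwargs)

    def dispatch_forward(self):
        if _is_cuda:
            return self.forward_cuda
        if _is_hip:
            return self.forward_hip
        return self.forward_native

    # Subclasses such as RMSNorm provide the backend-specific variants.
    def forward_native(self, *args, **kwargs):
        raise NotImplementedError

    def forward_cuda(self, *args, **kwargs):
        raise NotImplementedError

    def forward_hip(self, *args, **kwargs):
        return self.forward_native(*args, **kwargs)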
@@ -117,13 +107,9 @@ class GemmaRMSNorm(CustomOp):
         self.weight = nn.Parameter(torch.zeros(hidden_size))
         self.variance_epsilon = eps
 
-    def forward(self, *args, **kwargs):
-        if torch.compiler.is_compiling():
-            return self.forward_native(*args, **kwargs)
-        if _is_cuda:
-            return self.forward_cuda(*args, **kwargs)
-        else:
-            return self.forward_native(*args, **kwargs)
+        # Re-dispatch
+        if _is_hip:
+            self._forward_method = self.forward_native
 
     def forward_native(
         self,
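Unlike `RMSNorm`, the deleted `GemmaRMSNorm.forward` never had an `_is_hip` branch, i.e. there is no HIP-specific kernel to dispatch to. The three added constructor lines therefore pin the dispatched method to `forward_native` once, at init time, rather than re-checking the platform on every call. A self-contained sketch of that idea follows; the class name, the `forward_fused` stand-in, and the simplified math are illustrative, not the real implementation.

import torch
import torch.nn as nn

_is_hip = False  # assumed platform flag; sglang derives this from is_hip()


class GemmaRMSNormSketch(nn.Module):
    """Sketch of the constructor-time re-dispatch; not the real GemmaRMSNorm."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps
        # Stand-in for the base-class auto-dispatch (which would prefer a fused kernel).
        self._forward_method = self.forward_fused
        # Re-dispatch: this op has no HIP kernel, so on ROCm the dispatched method
        # is pinned to the pure-PyTorch implementation once, at construction time.
        if _is_hip:
            self._forward_method = self.forward_native

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self._forward_method(x)

    def forward_fused(self, x: torch.Tensor) -> torch.Tensor:
        # Placeholder for a fused-kernel call; falls back to the native math here.
        return self.forward_native(x)

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        # Gemma-style RMSNorm: zero-initialized weight applied as (1 + weight).
        orig_dtype = x.dtype
        x = x.float()
        variance = x.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.variance_epsilon)
        return (x * (1.0 + self.weight.float())).to(orig_dtype)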
@@ -8,9 +8,10 @@ import torch
 import torch.nn as nn
 
 from sglang.srt.custom_op import CustomOp
-from sglang.srt.utils import is_cuda
+from sglang.srt.utils import is_cuda, is_hip
 
 _is_cuda = is_cuda()
+_is_hip = is_hip()
 
 if _is_cuda:
     from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
@@ -609,6 +610,10 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
             head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
         )
 
+        # Re-dispatch
+        if _is_hip:
+            self._forward_method = self.forward_native
+
     def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
         pos_freqs = self.base ** (
             torch.arange(0, self.rotary_dim, 2, dtype=torch.float, device=self.device)
@@ -650,17 +655,6 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         cache = torch.cat((cos, sin), dim=-1)
         return cache
 
-    def forward_hip(self, *args, **kwargs):
-        return self.forward_native(*args, **kwargs)
-
-    def forward(self, *args, **kwargs):
-        if torch.compiler.is_compiling():
-            return self.forward_native(*args, **kwargs)
-        if _is_cuda:
-            return self.forward_cuda(*args, **kwargs)
-        else:
-            return self.forward_native(*args, **kwargs)
-
     def forward_native(
         self,
         positions: torch.Tensor,
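The last hunk also drops the `forward_hip` shim that merely forwarded to `forward_native`; the constructor-time re-dispatch added earlier makes it redundant. More broadly, none of these ops consult `torch.compiler.is_compiling()` themselves any more: a caller that wants compile-friendly behavior toggles it explicitly around the compiled region. The driver below is a hedged sketch of that call order only; it is made up and not how sglang actually wires this up.

import torch
import torch.nn as nn


def run_compiled(model: nn.Module, batch: torch.Tensor, num_tokens: int) -> torch.Tensor:
    """Illustrative driver: toggle compile mode on every CustomOp-style module."""
    # Switch each op to its compile-friendly forward before tracing.
    for module in model.modules():
        if hasattr(module, "enter_torch_compile"):
            module.enter_torch_compile(num_tokens=num_tokens)
    try:
        compiled = torch.compile(model)
        return compiled(batch)
    finally:
        # Restore whatever dispatch_forward() originally selected.
        for module in model.modules():
            if hasattr(module, "leave_torch_compile"):
                module.leave_torch_compile()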