[fix] recover auto-dispatch for rmsnorm and rope (#6745)
This commit is contained in:
@@ -11,7 +11,20 @@ class CustomOp(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self._forward_method = self.dispatch_forward()
|
self._forward_method = self.dispatch_forward()
|
||||||
|
|
||||||
|
# States for torch.compile
|
||||||
|
self._original_forward_method = None
|
||||||
|
self.is_torch_compile = False
|
||||||
|
|
||||||
def enter_torch_compile(self, num_tokens: int):
|
def enter_torch_compile(self, num_tokens: int):
|
||||||
|
# Skip if the Op has already entered compile mode.
|
||||||
|
# NOTE(alcanderian): Some Ops (for example RotaryEmbedding) will be reused
|
||||||
|
# among layers and `enter_torch_compile` will be called many times.
|
||||||
|
# We should prevent `self._original_forward_method` from being overridden when
|
||||||
|
# it is not the first time `enter_torch_compile` is called.
|
||||||
|
if self.is_torch_compile:
|
||||||
|
return
|
||||||
|
|
||||||
|
self._original_forward_method = self._forward_method
|
||||||
# NOTE: Temporary workaround for MoE
|
# NOTE: Temporary workaround for MoE
|
||||||
if "FusedMoE" in self.__class__.__name__:
|
if "FusedMoE" in self.__class__.__name__:
|
||||||
if num_tokens == 1:
|
if num_tokens == 1:
|
||||||
@@ -27,7 +40,12 @@ class CustomOp(nn.Module):
|
|||||||
self.is_torch_compile = True
|
self.is_torch_compile = True
|
||||||
|
|
||||||
def leave_torch_compile(self):
|
def leave_torch_compile(self):
|
||||||
self._forward_method = self.forward_cuda
|
# Skip if the Op has already exited compile mode.
|
||||||
|
if not self.is_torch_compile:
|
||||||
|
return
|
||||||
|
|
||||||
|
self._forward_method = self._original_forward_method
|
||||||
|
self._original_forward_method = None
|
||||||
self.is_torch_compile = False
|
self.is_torch_compile = False
|
||||||
|
|
||||||
# Please do not override this method, because `self._forward_method` can change when in torch compile mode
|
# Please do not override this method, because `self._forward_method` can change when in torch compile mode
|
||||||
|
|||||||
@@ -49,16 +49,6 @@ class RMSNorm(CustomOp):
|
|||||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||||
self.variance_epsilon = eps
|
self.variance_epsilon = eps
|
||||||
|
|
||||||
def forward(self, *args, **kwargs):
|
|
||||||
if torch.compiler.is_compiling():
|
|
||||||
return self.forward_native(*args, **kwargs)
|
|
||||||
if _is_cuda:
|
|
||||||
return self.forward_cuda(*args, **kwargs)
|
|
||||||
elif _is_hip:
|
|
||||||
return self.forward_hip(*args, **kwargs)
|
|
||||||
else:
|
|
||||||
return self.forward_native(*args, **kwargs)
|
|
||||||
|
|
||||||
def forward_cuda(
|
def forward_cuda(
|
||||||
self,
|
self,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
@@ -117,13 +107,9 @@ class GemmaRMSNorm(CustomOp):
|
|||||||
self.weight = nn.Parameter(torch.zeros(hidden_size))
|
self.weight = nn.Parameter(torch.zeros(hidden_size))
|
||||||
self.variance_epsilon = eps
|
self.variance_epsilon = eps
|
||||||
|
|
||||||
def forward(self, *args, **kwargs):
|
# Re-dispatch
|
||||||
if torch.compiler.is_compiling():
|
if _is_hip:
|
||||||
return self.forward_native(*args, **kwargs)
|
self._forward_method = self.forward_native
|
||||||
if _is_cuda:
|
|
||||||
return self.forward_cuda(*args, **kwargs)
|
|
||||||
else:
|
|
||||||
return self.forward_native(*args, **kwargs)
|
|
||||||
|
|
||||||
def forward_native(
|
def forward_native(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -8,9 +8,10 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from sglang.srt.custom_op import CustomOp
|
from sglang.srt.custom_op import CustomOp
|
||||||
from sglang.srt.utils import is_cuda
|
from sglang.srt.utils import is_cuda, is_hip
|
||||||
|
|
||||||
_is_cuda = is_cuda()
|
_is_cuda = is_cuda()
|
||||||
|
_is_hip = is_hip()
|
||||||
|
|
||||||
if _is_cuda:
|
if _is_cuda:
|
||||||
from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
|
from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
|
||||||
@@ -609,6 +610,10 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
|
|||||||
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
|
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Re-dispatch
|
||||||
|
if _is_hip:
|
||||||
|
self._forward_method = self.forward_native
|
||||||
|
|
||||||
def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
|
def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
|
||||||
pos_freqs = self.base ** (
|
pos_freqs = self.base ** (
|
||||||
torch.arange(0, self.rotary_dim, 2, dtype=torch.float, device=self.device)
|
torch.arange(0, self.rotary_dim, 2, dtype=torch.float, device=self.device)
|
||||||
@@ -650,17 +655,6 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
|
|||||||
cache = torch.cat((cos, sin), dim=-1)
|
cache = torch.cat((cos, sin), dim=-1)
|
||||||
return cache
|
return cache
|
||||||
|
|
||||||
def forward_hip(self, *args, **kwargs):
|
|
||||||
return self.forward_native(*args, **kwargs)
|
|
||||||
|
|
||||||
def forward(self, *args, **kwargs):
|
|
||||||
if torch.compiler.is_compiling():
|
|
||||||
return self.forward_native(*args, **kwargs)
|
|
||||||
if _is_cuda:
|
|
||||||
return self.forward_cuda(*args, **kwargs)
|
|
||||||
else:
|
|
||||||
return self.forward_native(*args, **kwargs)
|
|
||||||
|
|
||||||
def forward_native(
|
def forward_native(
|
||||||
self,
|
self,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
|
|||||||
Reference in New Issue
Block a user