[fix] recover auto-dispatch for rmsnorm and rope (#6745)

This commit is contained in:
JieXin Liang
2025-06-04 12:44:20 +08:00
committed by GitHub
parent 37f1547587
commit 180ff5eecc
3 changed files with 28 additions and 30 deletions

View File

@@ -49,16 +49,6 @@ class RMSNorm(CustomOp):
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, *args, **kwargs):
    """Route the call to the appropriate forward implementation.

    While torch.compile is tracing, the pure-PyTorch native path is
    always taken so the op remains traceable; otherwise the module-level
    platform flags select the CUDA or HIP entry point, with the native
    implementation as the final fallback.
    """
    if torch.compiler.is_compiling():
        return self.forward_native(*args, **kwargs)
    if _is_cuda:
        impl = self.forward_cuda
    elif _is_hip:
        impl = self.forward_hip
    else:
        impl = self.forward_native
    return impl(*args, **kwargs)
def forward_cuda(
self,
x: torch.Tensor,
@@ -117,13 +107,9 @@ class GemmaRMSNorm(CustomOp):
self.weight = nn.Parameter(torch.zeros(hidden_size))
self.variance_epsilon = eps
def forward(self, *args, **kwargs):
    """Dispatch to the CUDA or native forward implementation.

    torch.compile tracing always takes the native path so the op stays
    traceable; in eager mode the CUDA path is preferred whenever the
    module-level `_is_cuda` flag is set.
    """
    if torch.compiler.is_compiling():
        return self.forward_native(*args, **kwargs)
    impl = self.forward_cuda if _is_cuda else self.forward_native
    return impl(*args, **kwargs)
# Re-dispatch
if _is_hip:
self._forward_method = self.forward_native
def forward_native(
self,