Revert "Support aiter RMSNorm in AMD" (#5646)
This commit is contained in:
@@ -20,12 +20,9 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from sglang.srt.custom_op import CustomOp
|
from sglang.srt.custom_op import CustomOp
|
||||||
from sglang.srt.utils import is_cuda, is_hip
|
from sglang.srt.utils import is_cuda
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
_is_cuda = is_cuda()
|
_is_cuda = is_cuda()
|
||||||
_is_hip = is_hip()
|
|
||||||
|
|
||||||
if _is_cuda:
|
if _is_cuda:
|
||||||
from sgl_kernel import (
|
from sgl_kernel import (
|
||||||
@@ -35,20 +32,8 @@ if _is_cuda:
|
|||||||
rmsnorm,
|
rmsnorm,
|
||||||
)
|
)
|
||||||
|
|
||||||
if _is_hip:
|
|
||||||
|
|
||||||
from aiter.ops.rmsnorm import rms_norm, rmsnorm2d_fwd_with_add
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
rmsnorm = rms_norm
|
|
||||||
|
|
||||||
def fused_add_rmsnorm(
|
|
||||||
x: torch.Tensor,
|
|
||||||
residual: torch.Tensor,
|
|
||||||
w: torch.Tensor,
|
|
||||||
eps: float,
|
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
rmsnorm2d_fwd_with_add(x, x, residual, residual, w, eps)
|
|
||||||
return x, residual
|
|
||||||
|
|
||||||
|
|
||||||
class RMSNorm(CustomOp):
|
class RMSNorm(CustomOp):
|
||||||
@@ -154,7 +139,7 @@ class Gemma3RMSNorm(nn.Module):
|
|||||||
return f"{tuple(self.weight.shape)}, eps={self.eps}"
|
return f"{tuple(self.weight.shape)}, eps={self.eps}"
|
||||||
|
|
||||||
|
|
||||||
if not (_is_cuda or _is_hip):
|
if not _is_cuda:
|
||||||
logger.info(
|
logger.info(
|
||||||
"sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."
|
"sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user