Support aiter RMSNorm in AMD (#5510)
Co-authored-by: JieXin Liang <Alcanderian@users.noreply.github.com>
This commit is contained in:
@@ -20,9 +20,12 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from sglang.srt.custom_op import CustomOp
|
||||
from sglang.srt.utils import is_cuda
|
||||
from sglang.srt.utils import is_cuda, is_hip
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_is_cuda = is_cuda()
|
||||
_is_hip = is_hip()
|
||||
|
||||
if _is_cuda:
|
||||
from sgl_kernel import (
|
||||
@@ -32,8 +35,20 @@ if _is_cuda:
|
||||
rmsnorm,
|
||||
)
|
||||
|
||||
if _is_hip:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from aiter.ops.rmsnorm import rms_norm, rmsnorm2d_fwd_with_add
|
||||
|
||||
rmsnorm = rms_norm
|
||||
|
||||
def fused_add_rmsnorm(
|
||||
x: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
w: torch.Tensor,
|
||||
eps: float,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
rmsnorm2d_fwd_with_add(x, x, residual, residual, w, eps)
|
||||
return x, residual
|
||||
|
||||
|
||||
class RMSNorm(CustomOp):
|
||||
@@ -139,7 +154,7 @@ class Gemma3RMSNorm(nn.Module):
|
||||
return f"{tuple(self.weight.shape)}, eps={self.eps}"
|
||||
|
||||
|
||||
if not _is_cuda:
|
||||
if not (_is_cuda or _is_hip):
|
||||
logger.info(
|
||||
"sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user