Support aiter RMSNorm in AMD (#5510)
Co-authored-by: JieXin Liang <Alcanderian@users.noreply.github.com>
This commit is contained in:
@@ -20,9 +20,12 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from sglang.srt.custom_op import CustomOp
|
from sglang.srt.custom_op import CustomOp
|
||||||
from sglang.srt.utils import is_cuda
|
from sglang.srt.utils import is_cuda, is_hip
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
_is_cuda = is_cuda()
|
_is_cuda = is_cuda()
|
||||||
|
_is_hip = is_hip()
|
||||||
|
|
||||||
if _is_cuda:
|
if _is_cuda:
|
||||||
from sgl_kernel import (
|
from sgl_kernel import (
|
||||||
@@ -32,8 +35,20 @@ if _is_cuda:
|
|||||||
rmsnorm,
|
rmsnorm,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if _is_hip:
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
from aiter.ops.rmsnorm import rms_norm, rmsnorm2d_fwd_with_add
|
||||||
|
|
||||||
|
rmsnorm = rms_norm
|
||||||
|
|
||||||
|
def fused_add_rmsnorm(
|
||||||
|
x: torch.Tensor,
|
||||||
|
residual: torch.Tensor,
|
||||||
|
w: torch.Tensor,
|
||||||
|
eps: float,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
rmsnorm2d_fwd_with_add(x, x, residual, residual, w, eps)
|
||||||
|
return x, residual
|
||||||
|
|
||||||
|
|
||||||
class RMSNorm(CustomOp):
|
class RMSNorm(CustomOp):
|
||||||
@@ -139,7 +154,7 @@ class Gemma3RMSNorm(nn.Module):
|
|||||||
return f"{tuple(self.weight.shape)}, eps={self.eps}"
|
return f"{tuple(self.weight.shape)}, eps={self.eps}"
|
||||||
|
|
||||||
|
|
||||||
if not _is_cuda:
|
if not (_is_cuda or _is_hip):
|
||||||
logger.info(
|
logger.info(
|
||||||
"sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."
|
"sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user