AITER backend extension and workload optimizations (#6838)
Co-authored-by: wunhuang <wunhuang@amd.com> Co-authored-by: Hubert Lu <Hubert.Lu@amd.com>
This commit is contained in:
@@ -20,10 +20,11 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from sglang.srt.custom_op import CustomOp
|
||||
from sglang.srt.utils import is_cuda, is_hip
|
||||
from sglang.srt.utils import get_bool_env_var, is_cuda, is_hip
|
||||
|
||||
_is_cuda = is_cuda()
|
||||
_is_hip = is_hip()
|
||||
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
|
||||
|
||||
if _is_cuda:
|
||||
from sgl_kernel import (
|
||||
@@ -33,7 +34,10 @@ if _is_cuda:
|
||||
rmsnorm,
|
||||
)
|
||||
|
||||
if _is_hip:
|
||||
if _use_aiter:
|
||||
from aiter import rmsnorm2d_fwd as rms_norm
|
||||
from aiter import rmsnorm2d_fwd_with_add as fused_add_rms_norm
|
||||
elif _is_hip:
|
||||
from vllm._custom_ops import fused_add_rms_norm, rms_norm
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -48,6 +52,8 @@ class RMSNorm(CustomOp):
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
if _use_aiter:
|
||||
self._forward_method = self.forward_aiter
|
||||
|
||||
def forward_cuda(
|
||||
self,
|
||||
@@ -60,6 +66,25 @@ class RMSNorm(CustomOp):
|
||||
out = rmsnorm(x, self.weight.data, self.variance_epsilon)
|
||||
return out
|
||||
|
||||
def forward_aiter(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
if residual is not None:
|
||||
residual_out = torch.empty_like(x)
|
||||
output = torch.empty_like(x)
|
||||
fused_add_rms_norm(
|
||||
output,
|
||||
x,
|
||||
residual,
|
||||
residual_out,
|
||||
self.weight.data,
|
||||
self.variance_epsilon,
|
||||
)
|
||||
return output, residual_out
|
||||
return rms_norm(x, self.weight.data, self.variance_epsilon)
|
||||
|
||||
def forward_hip(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
|
||||
Reference in New Issue
Block a user