[AMD] Add silu_and_mul, gelu_and_mul, gelu_tanh_and_mul, and gelu_quick kernels for AMD GPUs (#7135)
Co-authored-by: yiakwy-xpu-ml-framework-team <961186938@qq.com> Co-authored-by: HAI <hixiao@gmail.com>
This commit is contained in:
@@ -33,6 +33,7 @@ from sglang.srt.utils import (
|
||||
cpu_has_amx_support,
|
||||
is_cpu,
|
||||
is_cuda,
|
||||
is_hip,
|
||||
is_npu,
|
||||
set_weight_attrs,
|
||||
)
|
||||
@@ -42,9 +43,12 @@ _is_cuda = is_cuda()
|
||||
_is_npu = is_npu()
|
||||
_is_cpu_amx_available = cpu_has_amx_support()
|
||||
_is_cpu = is_cpu()
|
||||
_is_hip = is_hip()
|
||||
|
||||
if _is_cuda:
|
||||
from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
|
||||
elif _is_hip:
|
||||
from sgl_kernel import gelu_and_mul, gelu_quick, gelu_tanh_and_mul, silu_and_mul
|
||||
|
||||
if is_npu():
|
||||
import torch_npu
|
||||
@@ -126,9 +130,13 @@ class QuickGELU(CustomOp):
|
||||
return x * torch.sigmoid(1.702 * x)
|
||||
|
||||
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
    """Apply QuickGELU on the CUDA path.

    No dedicated kernel is invoked here; the call is forwarded to
    ``forward_native`` and its result is returned unchanged.

    Args:
        x: Input tensor.

    Returns:
        Whatever ``self.forward_native(x)`` returns.
    """
    # TODO(zhyncs): Implement the CUDA kernel for QuickGELU in sgl-kernel
    native_result = self.forward_native(x)
    return native_result
|
||||
|
||||
def forward_hip(self, x: torch.Tensor) -> torch.Tensor:
    """Apply QuickGELU on the HIP path through the ``gelu_quick`` kernel.

    A new tensor with the same shape, dtype and device as ``x`` is
    allocated, passed to ``gelu_quick`` together with ``x``, and returned.

    Args:
        x: Input tensor.

    Returns:
        The freshly allocated tensor handed to ``gelu_quick`` as its second
        argument (presumably filled in place by the kernel — confirm against
        the sgl-kernel signature).
    """
    # Uninitialised buffer; its contents are only meaningful after the call.
    result = torch.empty(x.shape, device=x.device, dtype=x.dtype)
    gelu_quick(x, result)
    return result
|
||||
|
||||
|
||||
class ScaledActivation(nn.Module):
|
||||
"""An activation function with post-scale parameters.
|
||||
@@ -222,8 +230,8 @@ def get_cross_encoder_activation_function(config: PretrainedConfig):
|
||||
return nn.Identity()
|
||||
|
||||
|
||||
if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)):
|
||||
if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip):
|
||||
logger.info(
|
||||
"sgl-kernel is not available on Non-NV platforms or Non-AMX CPUs. Fallback to other kernel libraries."
|
||||
"sgl-kernel is not available on Non-NV, Non-AMD platforms or Non-AMX CPUs. Fallback to other kernel libraries."
|
||||
)
|
||||
from vllm.model_executor.layers.activation import GeluAndMul, SiluAndMul
|
||||
|
||||
@@ -3,9 +3,12 @@ import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.layers.activation import GeluAndMul
|
||||
from sglang.srt.layers.activation import GeluAndMul, QuickGELU
|
||||
from sglang.srt.utils import is_hip
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
_is_hip = is_hip()
|
||||
|
||||
|
||||
class TestGeluAndMul(CustomTestCase):
|
||||
DTYPES = [torch.half, torch.bfloat16]
|
||||
@@ -52,5 +55,51 @@ class TestGeluAndMul(CustomTestCase):
|
||||
self._run_gelu_and_mul_test(*params)
|
||||
|
||||
|
||||
class TestQuickGELU(CustomTestCase):
    """Check the platform QuickGELU forward path against ``forward_native``.

    On HIP the layer's ``forward_hip`` is exercised; otherwise
    ``forward_cuda`` is used. Every combination of the class-level
    token counts, hidden dims, dtypes and seeds is run as a sub-test.
    """

    DTYPES = [torch.half, torch.bfloat16]
    NUM_TOKENS = [7, 83, 2048]  # batch = sequence length
    DIMS = [512, 4096, 5120, 13824]  # all multiples of 16 bytes
    SEEDS = [0]

    @classmethod
    def setUpClass(cls):
        """Skip the whole class without CUDA, else default tensors to "cuda"."""
        if not torch.cuda.is_available():
            raise unittest.SkipTest("CUDA is not available")
        torch.set_default_device("cuda")

    def _run_gelu_quick_test(self, n_tok: int, dim: int, dtype: torch.dtype, seed: int):
        """Run one (n_tok, dim, dtype, seed) case and assert closeness.

        Args:
            n_tok: Number of rows in the random input.
            dim: Number of columns in the random input.
            dtype: Dtype of both the layer and the input.
            seed: Value passed to ``torch.manual_seed``.
        """
        torch.manual_seed(seed)

        layer = QuickGELU().to(dtype=dtype)
        x = torch.randn(n_tok, dim, dtype=dtype, device="cuda")

        # Pick the implementation under test once, based on the platform flag.
        forward_under_test = layer.forward_hip if _is_hip else layer.forward_cuda

        with torch.inference_mode():
            # Reference result comes first, then the platform path.
            ref = layer.forward_native(x)
            out = forward_under_test(x)

        # bfloat16 gets the looser tolerance; every other dtype uses 1e-3.
        if dtype is torch.bfloat16:
            tol = 1e-2
        else:
            tol = 1e-3

        is_close = torch.allclose(out, ref, atol=tol, rtol=tol)
        self.assertTrue(
            is_close,
            msg=f"Mismatch @ B={n_tok}, D={dim}, dtype={dtype}",
        )
        print(f"Match @ B={n_tok}, D={dim}, dtype={dtype}")

    def test_quick_gelu(self):
        """Sweep the full parameter grid, one sub-test per combination."""
        grid = itertools.product(
            self.NUM_TOKENS, self.DIMS, self.DTYPES, self.SEEDS
        )
        for n_tok, dim, dtype, seed in grid:
            with self.subTest(num_tokens=n_tok, dim=dim, dtype=dtype, seed=seed):
                self._run_gelu_quick_test(n_tok, dim, dtype, seed)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
Reference in New Issue
Block a user