[AMD] Add silu_and_mul, gelu_and_mul, gelu_tanh_and_mul, and gelu_quick kernels for AMD GPUs (#7135)

Co-authored-by: yiakwy-xpu-ml-framework-team <961186938@qq.com>
Co-authored-by: HAI <hixiao@gmail.com>
This commit is contained in:
Hubert Lu
2025-07-24 23:44:28 -07:00
committed by GitHub
parent 7ad6b766c5
commit af4b9bae95
17 changed files with 1226 additions and 61 deletions

View File

@@ -33,6 +33,7 @@ from sglang.srt.utils import (
cpu_has_amx_support,
is_cpu,
is_cuda,
is_hip,
is_npu,
set_weight_attrs,
)
@@ -42,9 +43,12 @@ _is_cuda = is_cuda()
_is_npu = is_npu()
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()
_is_hip = is_hip()
if _is_cuda:
from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
elif _is_hip:
from sgl_kernel import gelu_and_mul, gelu_quick, gelu_tanh_and_mul, silu_and_mul
if is_npu():
import torch_npu
@@ -126,9 +130,13 @@ class QuickGELU(CustomOp):
return x * torch.sigmoid(1.702 * x)
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
# TODO(zhyncs): Implement the CUDA kernel for QuickGELU in sgl-kernel
return self.forward_native(x)
def forward_hip(self, x: torch.Tensor) -> torch.Tensor:
out = torch.empty(x.shape, dtype=x.dtype, device=x.device)
gelu_quick(x, out)
return out
class ScaledActivation(nn.Module):
"""An activation function with post-scale parameters.
@@ -222,8 +230,8 @@ def get_cross_encoder_activation_function(config: PretrainedConfig):
return nn.Identity()
if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)):
if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip):
logger.info(
"sgl-kernel is not available on Non-NV platforms or Non-AMX CPUs. Fallback to other kernel libraries."
"sgl-kernel is not available on Non-NV, Non-AMD platforms or Non-AMX CPUs. Fallback to other kernel libraries."
)
from vllm.model_executor.layers.activation import GeluAndMul, SiluAndMul

View File

@@ -3,9 +3,12 @@ import unittest
import torch
from sglang.srt.layers.activation import GeluAndMul
from sglang.srt.layers.activation import GeluAndMul, QuickGELU
from sglang.srt.utils import is_hip
from sglang.test.test_utils import CustomTestCase
_is_hip = is_hip()
class TestGeluAndMul(CustomTestCase):
DTYPES = [torch.half, torch.bfloat16]
@@ -52,5 +55,51 @@ class TestGeluAndMul(CustomTestCase):
self._run_gelu_and_mul_test(*params)
class TestQuickGELU(CustomTestCase):
DTYPES = [torch.half, torch.bfloat16]
NUM_TOKENS = [7, 83, 2048] # batch = sequence length
DIMS = [512, 4096, 5120, 13824] # all multiples of 16 bytes
SEEDS = [0]
@classmethod
def setUpClass(cls):
if not torch.cuda.is_available():
raise unittest.SkipTest("CUDA is not available")
torch.set_default_device("cuda")
def _run_gelu_quick_test(self, n_tok: int, dim: int, dtype: torch.dtype, seed: int):
torch.manual_seed(seed)
layer = QuickGELU().to(dtype=dtype)
x = torch.randn(n_tok, dim, dtype=dtype, device="cuda")
with torch.inference_mode():
ref = layer.forward_native(x) # x * sigmoid(1.702 * x), fp32 math
if _is_hip:
out = layer.forward_hip(x) # 128-bit vectorised kernel from sgl-kernel
else:
out = layer.forward_cuda(x)
tol = 1e-2 if dtype is torch.bfloat16 else 1e-3
self.assertTrue(
torch.allclose(out, ref, atol=tol, rtol=tol),
msg=f"Mismatch @ B={n_tok}, D={dim}, dtype={dtype}",
)
print(f"Match @ B={n_tok}, D={dim}, dtype={dtype}")
def test_quick_gelu(self):
for params in itertools.product(
self.NUM_TOKENS, self.DIMS, self.DTYPES, self.SEEDS
):
with self.subTest(
num_tokens=params[0],
dim=params[1],
dtype=params[2],
seed=params[3],
):
self._run_gelu_quick_test(*params)
if __name__ == "__main__":
unittest.main(verbosity=2)