[AMD] Add silu_and_mul, gelu_and_mul, gelu_tanh_and_mul, and gelu_quick kernels for AMD GPUs (#7135)
Co-authored-by: yiakwy-xpu-ml-framework-team <961186938@qq.com> Co-authored-by: HAI <hixiao@gmail.com>
This commit is contained in:
@@ -3,9 +3,12 @@ import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.layers.activation import GeluAndMul
|
||||
from sglang.srt.layers.activation import GeluAndMul, QuickGELU
|
||||
from sglang.srt.utils import is_hip
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
_is_hip = is_hip()
|
||||
|
||||
|
||||
class TestGeluAndMul(CustomTestCase):
|
||||
DTYPES = [torch.half, torch.bfloat16]
|
||||
@@ -52,5 +55,51 @@ class TestGeluAndMul(CustomTestCase):
|
||||
self._run_gelu_and_mul_test(*params)
|
||||
|
||||
|
||||
class TestQuickGELU(CustomTestCase):
|
||||
DTYPES = [torch.half, torch.bfloat16]
|
||||
NUM_TOKENS = [7, 83, 2048] # batch = sequence length
|
||||
DIMS = [512, 4096, 5120, 13824] # all multiples of 16 bytes
|
||||
SEEDS = [0]
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
if not torch.cuda.is_available():
|
||||
raise unittest.SkipTest("CUDA is not available")
|
||||
torch.set_default_device("cuda")
|
||||
|
||||
def _run_gelu_quick_test(self, n_tok: int, dim: int, dtype: torch.dtype, seed: int):
|
||||
torch.manual_seed(seed)
|
||||
|
||||
layer = QuickGELU().to(dtype=dtype)
|
||||
|
||||
x = torch.randn(n_tok, dim, dtype=dtype, device="cuda")
|
||||
|
||||
with torch.inference_mode():
|
||||
ref = layer.forward_native(x) # x * sigmoid(1.702 * x), fp32 math
|
||||
if _is_hip:
|
||||
out = layer.forward_hip(x) # 128-bit vectorised kernel from sgl-kernel
|
||||
else:
|
||||
out = layer.forward_cuda(x)
|
||||
|
||||
tol = 1e-2 if dtype is torch.bfloat16 else 1e-3
|
||||
self.assertTrue(
|
||||
torch.allclose(out, ref, atol=tol, rtol=tol),
|
||||
msg=f"Mismatch @ B={n_tok}, D={dim}, dtype={dtype}",
|
||||
)
|
||||
print(f"Match @ B={n_tok}, D={dim}, dtype={dtype}")
|
||||
|
||||
def test_quick_gelu(self):
|
||||
for params in itertools.product(
|
||||
self.NUM_TOKENS, self.DIMS, self.DTYPES, self.SEEDS
|
||||
):
|
||||
with self.subTest(
|
||||
num_tokens=params[0],
|
||||
dim=params[1],
|
||||
dtype=params[2],
|
||||
seed=params[3],
|
||||
):
|
||||
self._run_gelu_quick_test(*params)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
Reference in New Issue
Block a user