feat: use gelu_tanh_and_mul (#1193)

Author: Yineng Zhang
Date: 2024-08-24 18:58:16 +10:00
Committed by: GitHub
Parent: a5b14ad043
Commit: c9064e6fd9

3 changed files with 74 additions and 3 deletions


@@ -0,0 +1,55 @@
import itertools
import unittest

import torch

from sglang.srt.layers.activation import GeluAndMul


class TestGeluAndMul(unittest.TestCase):
    DTYPES = [torch.half, torch.bfloat16]
    NUM_TOKENS = [7, 83, 2048]
    D = [512, 4096, 5120, 13824]
    SEEDS = [0]

    @classmethod
    def setUpClass(cls):
        if not torch.cuda.is_available():
            raise unittest.SkipTest("CUDA is not available")
        torch.set_default_device("cuda")

    def _run_gelu_and_mul_test(self, num_tokens, d, dtype, seed):
        torch.manual_seed(seed)

        layer = GeluAndMul().to(dtype=dtype)
        x = torch.randn(num_tokens, 2 * d, dtype=dtype)

        with torch.inference_mode():
            ref_out = layer.forward_native(x)
            out = layer.forward_cuda(x)

        if dtype == torch.bfloat16:
            atol = rtol = 1e-2
        else:
            atol = rtol = 1e-3

        self.assertTrue(torch.allclose(out, ref_out, atol=atol, rtol=rtol))

    def test_gelu_and_mul(self):
        for params in itertools.product(
            self.NUM_TOKENS,
            self.D,
            self.DTYPES,
            self.SEEDS,
        ):
            with self.subTest(
                num_tokens=params[0],
                d=params[1],
                dtype=params[2],
                seed=params[3],
            ):
                self._run_gelu_and_mul_test(*params)


if __name__ == "__main__":
    unittest.main(verbosity=2)
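
For context, the fused gelu_tanh_and_mul path exercised by forward_cuda above is compared against forward_native, which by the usual gated-activation convention applies a tanh-approximated GELU to the first half of the last dimension and multiplies it elementwise by the second half. The sketch below illustrates that assumed reference semantics only; it is not code from this commit, and the helper name gelu_tanh_and_mul_ref is hypothetical.

import torch


def gelu_tanh_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    # Split the last dimension (size 2 * d) into a gate half and an up half.
    d = x.shape[-1] // 2
    gate, up = x[..., :d], x[..., d:]
    # Tanh-approximated GELU on the gate, then elementwise multiply by the up half.
    return torch.nn.functional.gelu(gate, approximate="tanh") * up


# Usage: same input shape convention as the test above, (num_tokens, 2 * d).
x = torch.randn(7, 2 * 512)
print(gelu_tanh_and_mul_ref(x).shape)  # torch.Size([7, 512])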