[CPU] Add gelu_and_mul kernel in sgl-kernel and add ut (#9300)
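For context, a minimal invocation sketch of the new fused op (assuming sgl-kernel is built with its CPU backend, so that importing sgl_kernel registers the torch.ops.sgl_kernel.*_cpu ops used in the test below): the kernel takes a tensor whose last dimension is 2*d, applies GELU to the first half, and multiplies it elementwise by the second half, so the output last dimension is d. The shapes, dtype, and tolerances here are illustrative only; the test itself uses the precision table from utils.

import sgl_kernel  # importing registers the torch.ops.sgl_kernel.*_cpu ops
import torch
import torch.nn.functional as F

x = torch.randn(4, 2048, dtype=torch.bfloat16)
d = x.shape[-1] // 2

out = torch.ops.sgl_kernel.gelu_and_mul_cpu(x)            # exact (erf-based) GELU gating
out_tanh = torch.ops.sgl_kernel.gelu_tanh_and_mul_cpu(x)  # tanh-approximate variant

# Reference computation in plain PyTorch, matching the GeluAndMul helper below.
ref = F.gelu(x[..., :d], approximate="none") * x[..., d:]
torch.testing.assert_close(out, ref, atol=2e-2, rtol=2e-2)  # loose bf16 tolerance for this sketch
print(out.shape)  # torch.Size([4, 1024]): the last dimension is halved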
@@ -4,7 +4,7 @@ import unittest
 import sgl_kernel
 import torch
 import torch.nn.functional as F
-from utils import SiluAndMul, precision
+from utils import GeluAndMul, SiluAndMul, precision
 
 from sglang.test.test_utils import CustomTestCase
 
@@ -16,7 +16,7 @@ class TestActivation(CustomTestCase):
     N = [22016, 22018]
     dtype = [torch.float16, torch.bfloat16]
 
-    def _activation_test(self, m, n, dtype):
+    def _silu_and_mul_test(self, m, n, dtype):
         x = torch.randn([m, n], dtype=dtype)
 
         out = torch.ops.sgl_kernel.silu_and_mul_cpu(x)
@@ -25,10 +25,30 @@ class TestActivation(CustomTestCase):
         atol = rtol = precision[ref_out.dtype]
         torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol)
 
+    def _gelu_and_mul_test(self, m, n, dtype):
+        x = torch.randn([m, n], dtype=dtype)
+
+        out = torch.ops.sgl_kernel.gelu_and_mul_cpu(x)
+        ref_out = GeluAndMul(x, approximate="none")
+
+        atol = rtol = precision[ref_out.dtype]
+        torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol)
+
+    def _gelu_tanh_and_mul_test(self, m, n, dtype):
+        x = torch.randn([m, n], dtype=dtype)
+
+        out = torch.ops.sgl_kernel.gelu_tanh_and_mul_cpu(x)
+        ref_out = GeluAndMul(x, approximate="tanh")
+
+        atol = rtol = precision[ref_out.dtype]
+        torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol)
+
     def test_activation(self):
         for params in itertools.product(self.M, self.N, self.dtype):
             with self.subTest(m=params[0], n=params[1], dtype=params[2]):
-                self._activation_test(*params)
+                self._silu_and_mul_test(*params)
+                self._gelu_and_mul_test(*params)
+                self._gelu_tanh_and_mul_test(*params)
 
 
 if __name__ == "__main__":
@@ -20,6 +20,11 @@ def SiluAndMul(x: torch.Tensor) -> torch.Tensor:
     return F.silu(x[..., :d]) * x[..., d:]
 
 
+def GeluAndMul(x: torch.Tensor, approximate="tanh") -> torch.Tensor:
+    d = x.shape[-1] // 2
+    return F.gelu(x[..., :d], approximate=approximate) * x[..., d:]
+
+
 def per_token_quant_int8(x):
     x = x.float()
     absmax = x.abs().max(dim=-1).values
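The test keeps two GELU paths because F.gelu supports both the exact (erf-based) form and the tanh approximation; the two track each other closely but are not bit-identical, so each fused kernel is checked against the matching reference. A quick standalone illustration using only stock PyTorch:

import torch
import torch.nn.functional as F

x = torch.randn(2, 16, dtype=torch.float32)
d = x.shape[-1] // 2

exact = F.gelu(x[..., :d], approximate="none") * x[..., d:]
tanh_approx = F.gelu(x[..., :d], approximate="tanh") * x[..., d:]

# Small but nonzero gap between the two formulations, hence the separate
# references for gelu_and_mul_cpu and gelu_tanh_and_mul_cpu.
print((exact - tanh_approx).abs().max())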