diff --git a/python/sglang/srt/layers/activation.py b/python/sglang/srt/layers/activation.py
index 37832a3f7..3d973393e 100644
--- a/python/sglang/srt/layers/activation.py
+++ b/python/sglang/srt/layers/activation.py
@@ -110,6 +110,14 @@ class GeluAndMul(CustomOp):
         d = x.shape[-1] // 2
         return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]
 
+    def forward_cpu(self, x: torch.Tensor) -> torch.Tensor:
+        if _is_cpu_amx_available and self.approximate == "tanh":
+            return torch.ops.sgl_kernel.gelu_tanh_and_mul_cpu(x)
+        elif _is_cpu_amx_available and self.approximate == "none":
+            return torch.ops.sgl_kernel.gelu_and_mul_cpu(x)
+        else:
+            return self.forward_native(x)
+
     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
         return self._forward_impl(x)
 
diff --git a/sgl-kernel/csrc/cpu/activation.cpp b/sgl-kernel/csrc/cpu/activation.cpp
index debf5b244..70756776b 100644
--- a/sgl-kernel/csrc/cpu/activation.cpp
+++ b/sgl-kernel/csrc/cpu/activation.cpp
@@ -77,3 +77,59 @@ at::Tensor silu_and_mul_cpu(at::Tensor& input) {
   });
   return out;
 }
+
+at::Tensor gelu_tanh_and_mul_cpu(const at::Tensor& input) {
+  RECORD_FUNCTION("sgl-kernel::gelu_tanh_and_mul_cpu", std::vector<c10::IValue>({input}));
+  auto sizes = input.sizes().vec();
+  int64_t last_dim = input.ndimension() - 1;
+  int64_t d = sizes[last_dim] / 2;
+  sizes[last_dim] = d;
+  int64_t num_tokens = input.numel() / input.size(-1);
+  at::Tensor out = at::empty(sizes, input.options());
+  const float sqrt_2_div_pi = std::sqrt(2.f / M_PI);
+
+  AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "gelu_tanh_and_mul", [&] {
+    using Vec = at::vec::Vectorized<float>;
+    act_and_mul_kernel_impl(
+        out.data_ptr<scalar_t>(),
+        input.data_ptr<scalar_t>(),
+        num_tokens,
+        d,
+        [sqrt_2_div_pi](float x) {
+          float x3 = x * x * x;
+          float tanh_arg = sqrt_2_div_pi * (x + 0.044715f * x3);
+          return 0.5f * x * (1.f + std::tanh(tanh_arg));
+        },
+        [sqrt_2_div_pi](Vec x) {
+          Vec x3 = x * x * x;
+          Vec tanh_arg = Vec(sqrt_2_div_pi) * (x + Vec(0.044715f) * x3);
+          return Vec(0.5f) * x * (Vec(1.f) + tanh_arg.tanh());
+        });
+  });
+
+  return out;
+}
+
+at::Tensor gelu_and_mul_cpu(const at::Tensor& input) {
+  RECORD_FUNCTION("sgl-kernel::gelu_and_mul_cpu", std::vector<c10::IValue>({input}));
+  auto sizes = input.sizes().vec();
+  int64_t last_dim = input.ndimension() - 1;
+  int64_t d = sizes[last_dim] / 2;
+  sizes[last_dim] = d;
+  int64_t num_tokens = input.numel() / input.size(-1);
+  at::Tensor out = at::empty(sizes, input.options());
+
+  AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "gelu_and_mul", [&] {
+    using Vec = at::vec::Vectorized<float>;
+    const float inv_sqrt2 = 1.0f / std::sqrt(2.0f);
+    act_and_mul_kernel_impl(
+        out.data_ptr<scalar_t>(),
+        input.data_ptr<scalar_t>(),
+        num_tokens,
+        d,
+        [inv_sqrt2](float x) { return 0.5f * x * (1.f + std::erf(x * inv_sqrt2)); },
+        [inv_sqrt2](Vec x) { return Vec(0.5f) * x * (Vec(1.f) + (x * Vec(inv_sqrt2)).erf()); });
+  });
+
+  return out;
+}
diff --git a/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp b/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp
index 872c07628..2c8d9e3ec 100644
--- a/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp
+++ b/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp
@@ -23,6 +23,10 @@ limitations under the License.
 // silu_and_mul
 at::Tensor silu_and_mul_cpu(at::Tensor& input);
 
+// gelu_and_mul
+at::Tensor gelu_tanh_and_mul_cpu(const at::Tensor& input);
+at::Tensor gelu_and_mul_cpu(const at::Tensor& input);
+
 // l2norm
 at::Tensor l2norm_cpu(at::Tensor& input, double eps);
 
@@ -233,6 +237,10 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   // activation
   m.def("silu_and_mul_cpu(Tensor input) -> Tensor");
   m.impl("silu_and_mul_cpu", torch::kCPU, &silu_and_mul_cpu);
+  m.def("gelu_tanh_and_mul_cpu(Tensor input) -> Tensor");
+  m.impl("gelu_tanh_and_mul_cpu", torch::kCPU, &gelu_tanh_and_mul_cpu);
+  m.def("gelu_and_mul_cpu(Tensor input) -> Tensor");
+  m.impl("gelu_and_mul_cpu", torch::kCPU, &gelu_and_mul_cpu);
 
   // norm
   m.def("rmsnorm_cpu(Tensor input, Tensor weight, float eps) -> Tensor");
diff --git a/test/srt/cpu/test_activation.py b/test/srt/cpu/test_activation.py
index 23af99940..1234fc631 100644
--- a/test/srt/cpu/test_activation.py
+++ b/test/srt/cpu/test_activation.py
@@ -4,7 +4,7 @@ import unittest
 import sgl_kernel
 import torch
 import torch.nn.functional as F
-from utils import SiluAndMul, precision
+from utils import GeluAndMul, SiluAndMul, precision
 
 from sglang.test.test_utils import CustomTestCase
 
@@ -16,7 +16,7 @@ class TestActivation(CustomTestCase):
     N = [22016, 22018]
     dtype = [torch.float16, torch.bfloat16]
 
-    def _activation_test(self, m, n, dtype):
+    def _silu_and_mul_test(self, m, n, dtype):
         x = torch.randn([m, n], dtype=dtype)
 
         out = torch.ops.sgl_kernel.silu_and_mul_cpu(x)
@@ -25,10 +25,30 @@ class TestActivation(CustomTestCase):
         atol = rtol = precision[ref_out.dtype]
         torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol)
 
+    def _gelu_and_mul_test(self, m, n, dtype):
+        x = torch.randn([m, n], dtype=dtype)
+
+        out = torch.ops.sgl_kernel.gelu_and_mul_cpu(x)
+        ref_out = GeluAndMul(x, approximate="none")
+
+        atol = rtol = precision[ref_out.dtype]
+        torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol)
+
+    def _gelu_tanh_and_mul_test(self, m, n, dtype):
+        x = torch.randn([m, n], dtype=dtype)
+
+        out = torch.ops.sgl_kernel.gelu_tanh_and_mul_cpu(x)
+        ref_out = GeluAndMul(x, approximate="tanh")
+
+        atol = rtol = precision[ref_out.dtype]
+        torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol)
+
     def test_activation(self):
         for params in itertools.product(self.M, self.N, self.dtype):
             with self.subTest(m=params[0], n=params[1], dtype=params[2]):
-                self._activation_test(*params)
+                self._silu_and_mul_test(*params)
+                self._gelu_and_mul_test(*params)
+                self._gelu_tanh_and_mul_test(*params)
 
 
 if __name__ == "__main__":
diff --git a/test/srt/cpu/utils.py b/test/srt/cpu/utils.py
index b16b81bbf..6435dad74 100644
--- a/test/srt/cpu/utils.py
+++ b/test/srt/cpu/utils.py
@@ -20,6 +20,11 @@ def SiluAndMul(x: torch.Tensor) -> torch.Tensor:
     return F.silu(x[..., :d]) * x[..., d:]
 
 
+def GeluAndMul(x: torch.Tensor, approximate="tanh") -> torch.Tensor:
+    d = x.shape[-1] // 2
+    return F.gelu(x[..., :d], approximate=approximate) * x[..., d:]
+
+
 def per_token_quant_int8(x):
     x = x.float()
     absmax = x.abs().max(dim=-1).values
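
A minimal sanity-check sketch of the new ops, assuming sgl_kernel was built with the CPU extension above (importing it registers the torch.ops.sgl_kernel.* entries); the shape, dtype, and tolerance below are illustrative, loosely following the test case:

import torch
import torch.nn.functional as F
import sgl_kernel  # importing registers the torch.ops.sgl_kernel.* custom ops

x = torch.randn([128, 22016], dtype=torch.bfloat16)
d = x.shape[-1] // 2

# Fused GELU(tanh)-and-mul on CPU vs. the eager PyTorch reference.
out = torch.ops.sgl_kernel.gelu_tanh_and_mul_cpu(x)
ref = F.gelu(x[..., :d], approximate="tanh") * x[..., d:]
torch.testing.assert_close(ref, out, atol=1e-2, rtol=1e-2)  # illustrative bf16 tolerance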