[CPU] Add gelu_and_mul kernel in sgl-kernel and add ut (#9300)
This commit is contained in:
@@ -110,6 +110,14 @@ class GeluAndMul(CustomOp):
|
||||
d = x.shape[-1] // 2
|
||||
return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]
|
||||
|
||||
def forward_cpu(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if _is_cpu_amx_available and self.approximate == "tanh":
|
||||
return torch.ops.sgl_kernel.gelu_tanh_and_mul_cpu(x)
|
||||
elif _is_cpu_amx_available and self.approximate == "none":
|
||||
return torch.ops.sgl_kernel.gelu_and_mul_cpu(x)
|
||||
else:
|
||||
return self.forward_native(x)
|
||||
|
||||
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self._forward_impl(x)
|
||||
|
||||
|
||||
@@ -77,3 +77,59 @@ at::Tensor silu_and_mul_cpu(at::Tensor& input) {
|
||||
});
|
||||
return out;
|
||||
}
|
||||
|
||||
at::Tensor gelu_tanh_and_mul_cpu(const at::Tensor& input) {
|
||||
RECORD_FUNCTION("sgl-kernel::gelu_tanh_and_mul_cpu", std::vector<c10::IValue>({input}));
|
||||
auto sizes = input.sizes().vec();
|
||||
int64_t last_dim = input.ndimension() - 1;
|
||||
int64_t d = sizes[last_dim] / 2;
|
||||
sizes[last_dim] = d;
|
||||
int64_t num_tokens = input.numel() / input.size(-1);
|
||||
at::Tensor out = at::empty(sizes, input.options());
|
||||
const float sqrt_2_div_pi = std::sqrt(2.f / M_PI);
|
||||
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "gelu_tanh_and_mul", [&] {
|
||||
using Vec = at::vec::Vectorized<float>;
|
||||
act_and_mul_kernel_impl(
|
||||
out.data_ptr<scalar_t>(),
|
||||
input.data_ptr<scalar_t>(),
|
||||
num_tokens,
|
||||
d,
|
||||
[sqrt_2_div_pi](float x) {
|
||||
float x3 = x * x * x;
|
||||
float tanh_arg = sqrt_2_div_pi * (x + 0.044715f * x3);
|
||||
return 0.5f * x * (1.f + std::tanh(tanh_arg));
|
||||
},
|
||||
[sqrt_2_div_pi](Vec x) {
|
||||
Vec x3 = x * x * x;
|
||||
Vec tanh_arg = Vec(sqrt_2_div_pi) * (x + Vec(0.044715f) * x3);
|
||||
return Vec(0.5f) * x * (Vec(1.f) + tanh_arg.tanh());
|
||||
});
|
||||
});
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
at::Tensor gelu_and_mul_cpu(const at::Tensor& input) {
|
||||
RECORD_FUNCTION("sgl-kernel::gelu_and_mul_cpu", std::vector<c10::IValue>({input}));
|
||||
auto sizes = input.sizes().vec();
|
||||
int64_t last_dim = input.ndimension() - 1;
|
||||
int64_t d = sizes[last_dim] / 2;
|
||||
sizes[last_dim] = d;
|
||||
int64_t num_tokens = input.numel() / input.size(-1);
|
||||
at::Tensor out = at::empty(sizes, input.options());
|
||||
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "gelu_and_mul", [&] {
|
||||
using Vec = at::vec::Vectorized<float>;
|
||||
const float inv_sqrt2 = 1.0f / std::sqrt(2.0f);
|
||||
act_and_mul_kernel_impl(
|
||||
out.data_ptr<scalar_t>(),
|
||||
input.data_ptr<scalar_t>(),
|
||||
num_tokens,
|
||||
d,
|
||||
[inv_sqrt2](float x) { return 0.5f * x * (1.f + std::erf(x * inv_sqrt2)); },
|
||||
[inv_sqrt2](Vec x) { return Vec(0.5f) * x * (Vec(1.f) + (x * Vec(inv_sqrt2)).erf()); });
|
||||
});
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
@@ -23,6 +23,10 @@ limitations under the License.
|
||||
// silu_and_mul
|
||||
at::Tensor silu_and_mul_cpu(at::Tensor& input);
|
||||
|
||||
// gelu_and_mul
|
||||
at::Tensor gelu_tanh_and_mul_cpu(const at::Tensor& input);
|
||||
at::Tensor gelu_and_mul_cpu(const at::Tensor& input);
|
||||
|
||||
// l2norm
|
||||
at::Tensor l2norm_cpu(at::Tensor& input, double eps);
|
||||
|
||||
@@ -233,6 +237,10 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
||||
// activation
|
||||
m.def("silu_and_mul_cpu(Tensor input) -> Tensor");
|
||||
m.impl("silu_and_mul_cpu", torch::kCPU, &silu_and_mul_cpu);
|
||||
m.def("gelu_tanh_and_mul_cpu(Tensor input) -> Tensor");
|
||||
m.impl("gelu_tanh_and_mul_cpu", torch::kCPU, &gelu_tanh_and_mul_cpu);
|
||||
m.def("gelu_and_mul_cpu(Tensor input) -> Tensor");
|
||||
m.impl("gelu_and_mul_cpu", torch::kCPU, &gelu_and_mul_cpu);
|
||||
|
||||
// norm
|
||||
m.def("rmsnorm_cpu(Tensor input, Tensor weight, float eps) -> Tensor");
|
||||
|
||||
@@ -4,7 +4,7 @@ import unittest
|
||||
import sgl_kernel
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from utils import SiluAndMul, precision
|
||||
from utils import GeluAndMul, SiluAndMul, precision
|
||||
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
@@ -16,7 +16,7 @@ class TestActivation(CustomTestCase):
|
||||
N = [22016, 22018]
|
||||
dtype = [torch.float16, torch.bfloat16]
|
||||
|
||||
def _activation_test(self, m, n, dtype):
|
||||
def _silu_and_mul_test(self, m, n, dtype):
|
||||
x = torch.randn([m, n], dtype=dtype)
|
||||
|
||||
out = torch.ops.sgl_kernel.silu_and_mul_cpu(x)
|
||||
@@ -25,10 +25,30 @@ class TestActivation(CustomTestCase):
|
||||
atol = rtol = precision[ref_out.dtype]
|
||||
torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol)
|
||||
|
||||
def _gelu_and_mul_test(self, m, n, dtype):
|
||||
x = torch.randn([m, n], dtype=dtype)
|
||||
|
||||
out = torch.ops.sgl_kernel.gelu_and_mul_cpu(x)
|
||||
ref_out = GeluAndMul(x, approximate="none")
|
||||
|
||||
atol = rtol = precision[ref_out.dtype]
|
||||
torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol)
|
||||
|
||||
def _gelu_tanh_and_mul_test(self, m, n, dtype):
|
||||
x = torch.randn([m, n], dtype=dtype)
|
||||
|
||||
out = torch.ops.sgl_kernel.gelu_tanh_and_mul_cpu(x)
|
||||
ref_out = GeluAndMul(x, approximate="tanh")
|
||||
|
||||
atol = rtol = precision[ref_out.dtype]
|
||||
torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol)
|
||||
|
||||
def test_activation(self):
|
||||
for params in itertools.product(self.M, self.N, self.dtype):
|
||||
with self.subTest(m=params[0], n=params[1], dtype=params[2]):
|
||||
self._activation_test(*params)
|
||||
self._silu_and_mul_test(*params)
|
||||
self._gelu_and_mul_test(*params)
|
||||
self._gelu_tanh_and_mul_test(*params)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -20,6 +20,11 @@ def SiluAndMul(x: torch.Tensor) -> torch.Tensor:
|
||||
return F.silu(x[..., :d]) * x[..., d:]
|
||||
|
||||
|
||||
def GeluAndMul(x: torch.Tensor, approximate="tanh") -> torch.Tensor:
|
||||
d = x.shape[-1] // 2
|
||||
return F.gelu(x[..., :d], approximate=approximate) * x[..., d:]
|
||||
|
||||
|
||||
def per_token_quant_int8(x):
|
||||
x = x.float()
|
||||
absmax = x.abs().max(dim=-1).values
|
||||
|
||||
Reference in New Issue
Block a user