[CPU] Add gelu_and_mul kernel in sgl-kernel and add ut (#9300)
@@ -110,6 +110,14 @@ class GeluAndMul(CustomOp):
         d = x.shape[-1] // 2
         return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]
 
+    def forward_cpu(self, x: torch.Tensor) -> torch.Tensor:
+        if _is_cpu_amx_available and self.approximate == "tanh":
+            return torch.ops.sgl_kernel.gelu_tanh_and_mul_cpu(x)
+        elif _is_cpu_amx_available and self.approximate == "none":
+            return torch.ops.sgl_kernel.gelu_and_mul_cpu(x)
+        else:
+            return self.forward_native(x)
+
     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
         return self._forward_impl(x)
 
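The new forward_cpu routes both GELU variants to the fused CPU ops when AMX is available and falls back to forward_native otherwise. A minimal sketch of the fused op's contract (reference semantics only; the commented-out check assumes a CPU build of sgl_kernel is installed):

import torch
import torch.nn.functional as F

# gelu_*_and_mul splits the last dimension in half, applies GELU to the
# first half, and multiplies elementwise by the second half, so the output
# has half the input's last dimension.
x = torch.randn(4, 2048, dtype=torch.bfloat16)
d = x.shape[-1] // 2
ref = F.gelu(x[..., :d], approximate="tanh") * x[..., d:]

# With the extension loaded, the fused op should agree with the reference:
# out = torch.ops.sgl_kernel.gelu_tanh_and_mul_cpu(x)
# torch.testing.assert_close(out, ref, atol=1e-2, rtol=1e-2)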
@@ -77,3 +77,59 @@ at::Tensor silu_and_mul_cpu(at::Tensor& input) {
   });
   return out;
 }
+
+at::Tensor gelu_tanh_and_mul_cpu(const at::Tensor& input) {
+  RECORD_FUNCTION("sgl-kernel::gelu_tanh_and_mul_cpu", std::vector<c10::IValue>({input}));
+  auto sizes = input.sizes().vec();
+  int64_t last_dim = input.ndimension() - 1;
+  int64_t d = sizes[last_dim] / 2;
+  sizes[last_dim] = d;
+  int64_t num_tokens = input.numel() / input.size(-1);
+  at::Tensor out = at::empty(sizes, input.options());
+  const float sqrt_2_div_pi = std::sqrt(2.f / M_PI);
+
+  AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "gelu_tanh_and_mul", [&] {
+    using Vec = at::vec::Vectorized<float>;
+    act_and_mul_kernel_impl(
+        out.data_ptr<scalar_t>(),
+        input.data_ptr<scalar_t>(),
+        num_tokens,
+        d,
+        [sqrt_2_div_pi](float x) {
+          float x3 = x * x * x;
+          float tanh_arg = sqrt_2_div_pi * (x + 0.044715f * x3);
+          return 0.5f * x * (1.f + std::tanh(tanh_arg));
+        },
+        [sqrt_2_div_pi](Vec x) {
+          Vec x3 = x * x * x;
+          Vec tanh_arg = Vec(sqrt_2_div_pi) * (x + Vec(0.044715f) * x3);
+          return Vec(0.5f) * x * (Vec(1.f) + tanh_arg.tanh());
+        });
+  });
+
+  return out;
+}
+
+at::Tensor gelu_and_mul_cpu(const at::Tensor& input) {
+  RECORD_FUNCTION("sgl-kernel::gelu_and_mul_cpu", std::vector<c10::IValue>({input}));
+  auto sizes = input.sizes().vec();
+  int64_t last_dim = input.ndimension() - 1;
+  int64_t d = sizes[last_dim] / 2;
+  sizes[last_dim] = d;
+  int64_t num_tokens = input.numel() / input.size(-1);
+  at::Tensor out = at::empty(sizes, input.options());
+
+  AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "gelu_and_mul", [&] {
+    using Vec = at::vec::Vectorized<float>;
+    const float inv_sqrt2 = 1.0f / std::sqrt(2.0f);
+    act_and_mul_kernel_impl(
+        out.data_ptr<scalar_t>(),
+        input.data_ptr<scalar_t>(),
+        num_tokens,
+        d,
+        [inv_sqrt2](float x) { return 0.5f * x * (1.f + std::erf(x * inv_sqrt2)); },
+        [inv_sqrt2](Vec x) { return Vec(0.5f) * x * (Vec(1.f) + (x * Vec(inv_sqrt2)).erf()); });
+  });
+
+  return out;
+}
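The scalar and vectorized lambdas implement the two standard GELU forms: the tanh approximation 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))) and the exact erf form 0.5 * x * (1 + erf(x / sqrt(2))). A quick pure-PyTorch sketch (no sgl_kernel required) that checks both formulas against F.gelu:

import math
import torch
import torch.nn.functional as F

x = torch.linspace(-4.0, 4.0, steps=101)

# Tanh approximation, matching the lambdas in gelu_tanh_and_mul_cpu.
tanh_gelu = 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * x**3)))
torch.testing.assert_close(tanh_gelu, F.gelu(x, approximate="tanh"))

# Exact erf form, matching the lambdas in gelu_and_mul_cpu.
erf_gelu = 0.5 * x * (1 + torch.erf(x / math.sqrt(2)))
torch.testing.assert_close(erf_gelu, F.gelu(x))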
@@ -23,6 +23,10 @@ limitations under the License.
 // silu_and_mul
 at::Tensor silu_and_mul_cpu(at::Tensor& input);
 
+// gelu_and_mul
+at::Tensor gelu_tanh_and_mul_cpu(const at::Tensor& input);
+at::Tensor gelu_and_mul_cpu(const at::Tensor& input);
+
 // l2norm
 at::Tensor l2norm_cpu(at::Tensor& input, double eps);
 
@@ -233,6 +237,10 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   // activation
   m.def("silu_and_mul_cpu(Tensor input) -> Tensor");
   m.impl("silu_and_mul_cpu", torch::kCPU, &silu_and_mul_cpu);
+  m.def("gelu_tanh_and_mul_cpu(Tensor input) -> Tensor");
+  m.impl("gelu_tanh_and_mul_cpu", torch::kCPU, &gelu_tanh_and_mul_cpu);
+  m.def("gelu_and_mul_cpu(Tensor input) -> Tensor");
+  m.impl("gelu_and_mul_cpu", torch::kCPU, &gelu_and_mul_cpu);
 
   // norm
   m.def("rmsnorm_cpu(Tensor input, Tensor weight, float eps) -> Tensor");
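Once m.def registers the schema and m.impl binds the CPU implementation, the ops are callable from Python under torch.ops.sgl_kernel, which is how the unit tests below invoke them. A minimal usage sketch, assuming a CPU build of sgl-kernel is installed:

import torch
import sgl_kernel  # importing the package loads the extension and registers the ops

x = torch.randn(8, 4096, dtype=torch.bfloat16)
out = torch.ops.sgl_kernel.gelu_and_mul_cpu(x)
print(out.shape)  # torch.Size([8, 2048]): the last dimension is halved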
@@ -4,7 +4,7 @@ import unittest
 import sgl_kernel
 import torch
 import torch.nn.functional as F
-from utils import SiluAndMul, precision
+from utils import GeluAndMul, SiluAndMul, precision
 
 from sglang.test.test_utils import CustomTestCase
 
@@ -16,7 +16,7 @@ class TestActivation(CustomTestCase):
     N = [22016, 22018]
     dtype = [torch.float16, torch.bfloat16]
 
-    def _activation_test(self, m, n, dtype):
+    def _silu_and_mul_test(self, m, n, dtype):
         x = torch.randn([m, n], dtype=dtype)
 
         out = torch.ops.sgl_kernel.silu_and_mul_cpu(x)
@@ -25,10 +25,30 @@ class TestActivation(CustomTestCase):
         atol = rtol = precision[ref_out.dtype]
         torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol)
 
+    def _gelu_and_mul_test(self, m, n, dtype):
+        x = torch.randn([m, n], dtype=dtype)
+
+        out = torch.ops.sgl_kernel.gelu_and_mul_cpu(x)
+        ref_out = GeluAndMul(x, approximate="none")
+
+        atol = rtol = precision[ref_out.dtype]
+        torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol)
+
+    def _gelu_tanh_and_mul_test(self, m, n, dtype):
+        x = torch.randn([m, n], dtype=dtype)
+
+        out = torch.ops.sgl_kernel.gelu_tanh_and_mul_cpu(x)
+        ref_out = GeluAndMul(x, approximate="tanh")
+
+        atol = rtol = precision[ref_out.dtype]
+        torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol)
+
     def test_activation(self):
         for params in itertools.product(self.M, self.N, self.dtype):
             with self.subTest(m=params[0], n=params[1], dtype=params[2]):
-                self._activation_test(*params)
+                self._silu_and_mul_test(*params)
+                self._gelu_and_mul_test(*params)
+                self._gelu_tanh_and_mul_test(*params)
 
 
 if __name__ == "__main__":
@@ -20,6 +20,11 @@ def SiluAndMul(x: torch.Tensor) -> torch.Tensor:
     return F.silu(x[..., :d]) * x[..., d:]
 
 
+def GeluAndMul(x: torch.Tensor, approximate="tanh") -> torch.Tensor:
+    d = x.shape[-1] // 2
+    return F.gelu(x[..., :d], approximate=approximate) * x[..., d:]
+
+
 def per_token_quant_int8(x):
     x = x.float()
     absmax = x.abs().max(dim=-1).values
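GeluAndMul mirrors forward_native from the Python layer and serves as the reference for both new tests: approximate="tanh" pairs with gelu_tanh_and_mul_cpu, and approximate="none" with gelu_and_mul_cpu. For example:

import torch

x = torch.randn(2, 64, dtype=torch.float16)
ref_tanh = GeluAndMul(x, approximate="tanh")   # reference for gelu_tanh_and_mul_cpu
ref_exact = GeluAndMul(x, approximate="none")  # reference for gelu_and_mul_cpu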