[CPU] Add gelu_and_mul kernel in sgl-kernel and add ut (#9300)
This commit is contained in:
@@ -77,3 +77,59 @@ at::Tensor silu_and_mul_cpu(at::Tensor& input) {
|
||||
});
|
||||
return out;
|
||||
}
|
||||
|
||||
at::Tensor gelu_tanh_and_mul_cpu(const at::Tensor& input) {
|
||||
RECORD_FUNCTION("sgl-kernel::gelu_tanh_and_mul_cpu", std::vector<c10::IValue>({input}));
|
||||
auto sizes = input.sizes().vec();
|
||||
int64_t last_dim = input.ndimension() - 1;
|
||||
int64_t d = sizes[last_dim] / 2;
|
||||
sizes[last_dim] = d;
|
||||
int64_t num_tokens = input.numel() / input.size(-1);
|
||||
at::Tensor out = at::empty(sizes, input.options());
|
||||
const float sqrt_2_div_pi = std::sqrt(2.f / M_PI);
|
||||
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "gelu_tanh_and_mul", [&] {
|
||||
using Vec = at::vec::Vectorized<float>;
|
||||
act_and_mul_kernel_impl(
|
||||
out.data_ptr<scalar_t>(),
|
||||
input.data_ptr<scalar_t>(),
|
||||
num_tokens,
|
||||
d,
|
||||
[sqrt_2_div_pi](float x) {
|
||||
float x3 = x * x * x;
|
||||
float tanh_arg = sqrt_2_div_pi * (x + 0.044715f * x3);
|
||||
return 0.5f * x * (1.f + std::tanh(tanh_arg));
|
||||
},
|
||||
[sqrt_2_div_pi](Vec x) {
|
||||
Vec x3 = x * x * x;
|
||||
Vec tanh_arg = Vec(sqrt_2_div_pi) * (x + Vec(0.044715f) * x3);
|
||||
return Vec(0.5f) * x * (Vec(1.f) + tanh_arg.tanh());
|
||||
});
|
||||
});
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
at::Tensor gelu_and_mul_cpu(const at::Tensor& input) {
|
||||
RECORD_FUNCTION("sgl-kernel::gelu_and_mul_cpu", std::vector<c10::IValue>({input}));
|
||||
auto sizes = input.sizes().vec();
|
||||
int64_t last_dim = input.ndimension() - 1;
|
||||
int64_t d = sizes[last_dim] / 2;
|
||||
sizes[last_dim] = d;
|
||||
int64_t num_tokens = input.numel() / input.size(-1);
|
||||
at::Tensor out = at::empty(sizes, input.options());
|
||||
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "gelu_and_mul", [&] {
|
||||
using Vec = at::vec::Vectorized<float>;
|
||||
const float inv_sqrt2 = 1.0f / std::sqrt(2.0f);
|
||||
act_and_mul_kernel_impl(
|
||||
out.data_ptr<scalar_t>(),
|
||||
input.data_ptr<scalar_t>(),
|
||||
num_tokens,
|
||||
d,
|
||||
[inv_sqrt2](float x) { return 0.5f * x * (1.f + std::erf(x * inv_sqrt2)); },
|
||||
[inv_sqrt2](Vec x) { return Vec(0.5f) * x * (Vec(1.f) + (x * Vec(inv_sqrt2)).erf()); });
|
||||
});
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user