[CPU] Add gelu_and_mul kernel in sgl-kernel and add ut (#9300)
This commit is contained in:
@@ -23,6 +23,10 @@ limitations under the License.
 // silu_and_mul
 at::Tensor silu_and_mul_cpu(at::Tensor& input);
 
+// gelu_and_mul
+at::Tensor gelu_tanh_and_mul_cpu(const at::Tensor& input);
+at::Tensor gelu_and_mul_cpu(const at::Tensor& input);
+
 // l2norm
 at::Tensor l2norm_cpu(at::Tensor& input, double eps);
@@ -233,6 +237,10 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   // activation
   m.def("silu_and_mul_cpu(Tensor input) -> Tensor");
   m.impl("silu_and_mul_cpu", torch::kCPU, &silu_and_mul_cpu);
+  m.def("gelu_tanh_and_mul_cpu(Tensor input) -> Tensor");
+  m.impl("gelu_tanh_and_mul_cpu", torch::kCPU, &gelu_tanh_and_mul_cpu);
+  m.def("gelu_and_mul_cpu(Tensor input) -> Tensor");
+  m.impl("gelu_and_mul_cpu", torch::kCPU, &gelu_and_mul_cpu);
 
   // norm
   m.def("rmsnorm_cpu(Tensor input, Tensor weight, float eps) -> Tensor");
Reference in New Issue
Block a user