Add CPU optimized kernels for topk and rope fusions (#6456)

This commit is contained in:
jianan-gu
2025-06-03 08:37:34 +08:00
committed by GitHub
parent ff91474825
commit ff00895c46
7 changed files with 829 additions and 98 deletions

View File

@@ -23,6 +23,9 @@ limitations under the License.
// silu_and_mul
// Forward declaration only: takes one tensor and returns a tensor.
// NOTE(review): presumably applies SiLU to one half of `input` and multiplies by the other half — confirm against the kernel.
at::Tensor silu_and_mul_cpu(at::Tensor& input);
// l2norm
// Forward declaration only: takes the input tensor plus `eps` and returns a tensor.
// Bound as the CPU implementation of the "l2norm_cpu" op in the TORCH_LIBRARY_FRAGMENT block;
// the schema's `float eps` corresponds to the `double eps` parameter here.
// NOTE(review): `eps` is presumably the numerical-stability term of the norm — confirm against the kernel.
at::Tensor l2norm_cpu(at::Tensor& input, double eps);
// rmsnorm
// Forward declaration only: takes the input tensor, a weight tensor and `eps`, and returns a tensor.
// Bound as the CPU implementation of the "rmsnorm_cpu" op in the TORCH_LIBRARY_FRAGMENT block.
at::Tensor rmsnorm_cpu(at::Tensor& input, at::Tensor& weight, double eps);
@@ -30,6 +33,11 @@ at::Tensor rmsnorm_cpu(at::Tensor& input, at::Tensor& weight, double eps);
// Forward declaration only. Returns void, matching the "-> ()" schema it is registered with below.
// NOTE(review): with no return value, the result is presumably written in place into `input` and/or
// `residual` — confirm against the kernel before relying on either tensor being unchanged.
void fused_add_rmsnorm_cpu(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps);
// topk
// Forward declarations only. The sigmoid and softmax variants share one signature:
// (hidden_states, gating_output, topk, renormalize) -> pair of tensors.
// Both are bound as CPU implementations of the same-named ops in the TORCH_LIBRARY_FRAGMENT block.
// NOTE(review): the returned pair is presumably (topk weights, topk expert ids) — confirm against the kernel.
// NOTE(review): `renormalize` presumably rescales the selected weights to sum to 1 — confirm.
std::tuple<at::Tensor, at::Tensor>
topk_sigmoid_cpu(at::Tensor& hidden_states, at::Tensor& gating_output, int64_t topk, bool renormalize);
std::tuple<at::Tensor, at::Tensor>
topk_softmax_cpu(at::Tensor& hidden_states, at::Tensor& gating_output, int64_t topk, bool renormalize);
std::tuple<at::Tensor, at::Tensor> grouped_topk_cpu(
at::Tensor& hidden_states,
at::Tensor& gating_output,
@@ -185,8 +193,13 @@ void shm_allreduce(
// Forward declaration only: takes a tensor, a c10d process group handle and a dimension index,
// and returns a tensor. Bound via m.impl("shm_allgather", torch::kCPU, ...) in the registration block.
// NOTE(review): "shm" presumably stands for shared memory, and `dim` for the gather/concat axis — confirm.
at::Tensor shm_allgather(at::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> process_group, int64_t dim);
// rope
// Forward declarations only. Two rotary-embedding entry points, each returning a pair of tensors:
//   - rotary_position_embedding_cpu(t_pos, q_pe, k_pe, t_emb_pos)
//   - rotary_embedding_cpu(positions, query, key, head_size, cos_sin_cache, is_neox)
// Both are bound as CPU implementations of the same-named ops in the TORCH_LIBRARY_FRAGMENT block.
// NOTE(review): the returned pair is presumably the rotated (query, key) — confirm against the kernel.
std::tuple<at::Tensor, at::Tensor>
rotary_position_embedding_cpu(at::Tensor& t_pos, at::Tensor& q_pe, at::Tensor& k_pe, at::Tensor& t_emb_pos);
// NOTE(review): `is_neox` presumably selects NeoX-style (split-half) rotation over the interleaved
// layout, and `cos_sin_cache` presumably holds precomputed cos/sin values indexed by position — confirm.
std::tuple<at::Tensor, at::Tensor> rotary_embedding_cpu(
at::Tensor& positions,
at::Tensor& query,
at::Tensor& key,
int64_t head_size,
at::Tensor& cos_sin_cache,
bool is_neox);
TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
// activation
@@ -196,10 +209,16 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
// norm
// Each op is registered in two steps: m.def declares the operator schema, then m.impl binds the
// C++ function as that operator's implementation for the CPU dispatch key.
// In the schema strings, `float` corresponds to the `double eps` parameter of the bound functions.
m.def("rmsnorm_cpu(Tensor input, Tensor weight, float eps) -> Tensor");
m.impl("rmsnorm_cpu", torch::kCPU, &rmsnorm_cpu);
m.def("l2norm_cpu(Tensor input, float eps) -> Tensor");
m.impl("l2norm_cpu", torch::kCPU, &l2norm_cpu);
// "-> ()" declares no return value, matching the void return type of fused_add_rmsnorm_cpu.
m.def("fused_add_rmsnorm_cpu(Tensor input, Tensor residual, Tensor weight, float eps) -> ()");
m.impl("fused_add_rmsnorm_cpu", torch::kCPU, &fused_add_rmsnorm_cpu);
// topk
// Schema `int` corresponds to the `int64_t topk` parameter of the bound functions, and
// "(Tensor, Tensor)" to their std::tuple<at::Tensor, at::Tensor> return type.
// The sigmoid and softmax variants are registered with identical argument lists.
m.def("topk_sigmoid_cpu(Tensor hidden_states, Tensor gating_output, int topk, bool renormalize) -> (Tensor, Tensor)");
m.impl("topk_sigmoid_cpu", torch::kCPU, &topk_sigmoid_cpu);
m.def("topk_softmax_cpu(Tensor hidden_states, Tensor gating_output, int topk, bool renormalize) -> (Tensor, Tensor)");
m.impl("topk_softmax_cpu", torch::kCPU, &topk_softmax_cpu);
// grouped_topk_cpu takes the same leading arguments plus num_expert_group and topk_group.
// The schema is written as two adjacent string literals, which the compiler concatenates into one.
m.def(
"grouped_topk_cpu(Tensor hidden_states, Tensor gating_output, int topk, bool renormalize, int num_expert_group, "
"int topk_group) -> (Tensor, Tensor)");
@@ -294,8 +313,10 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
m.impl("shm_allgather", torch::kCPU, &shm_allgather);
// rope
// Registers both rotary-embedding ops: m.def declares the schema, m.impl binds the CPU function.
// Each schema returns "(Tensor, Tensor)", matching the std::tuple<at::Tensor, at::Tensor> return type
// of the bound functions.
m.def("rotary_position_embedding_cpu(Tensor t_pos, Tensor q_pe, Tensor k_pe, Tensor t_emb_pos) -> (Tensor, Tensor)");
m.impl("rotary_position_embedding_cpu", torch::kCPU, &rotary_position_embedding_cpu);
// Schema `int head_size` corresponds to the `int64_t head_size` parameter of rotary_embedding_cpu.
// The schema is written as two adjacent string literals, which the compiler concatenates into one.
m.def(
"rotary_embedding_cpu(Tensor positions, Tensor query, Tensor key, int head_size, Tensor cos_sin_cache, "
"bool is_neox) -> (Tensor, Tensor)");
m.impl("rotary_embedding_cpu", torch::kCPU, &rotary_embedding_cpu);
}
// REGISTER_EXTENSION is a macro defined outside this section of the file.
// NOTE(review): presumably emits the Python module entry point for an extension named `common_ops`
// so that importing it triggers the op registrations above — confirm against the macro definition.
REGISTER_EXTENSION(common_ops)