Add CPU optimized kernels for topk and rope fusions (#6456)
@@ -23,6 +23,9 @@ limitations under the License.

// silu_and_mul
at::Tensor silu_and_mul_cpu(at::Tensor& input);

// l2norm
at::Tensor l2norm_cpu(at::Tensor& input, double eps);

// rmsnorm
at::Tensor rmsnorm_cpu(at::Tensor& input, at::Tensor& weight, double eps);
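For reference, `rmsnorm_cpu` computes the usual RMSNorm, y = x * w / sqrt(mean(x^2) + eps) per row, and `silu_and_mul_cpu` gates the second half of the last dimension by SiLU of the first half. A minimal scalar sketch of the rmsnorm math, assuming a row-major [num_tokens, hidden] float layout (an assumption; the diff only shows the signature):

```cpp
#include <cmath>
#include <cstdint>

// Scalar reference for the rmsnorm math (assumed semantics; the kernel
// added in this commit is a vectorized CPU implementation of the same).
void rmsnorm_ref(const float* x, const float* w, float* y,
                 int64_t rows, int64_t hidden, float eps) {
  for (int64_t r = 0; r < rows; ++r) {
    const float* xr = x + r * hidden;
    float sum_sq = 0.f;
    for (int64_t i = 0; i < hidden; ++i) sum_sq += xr[i] * xr[i];
    float rrms = 1.f / std::sqrt(sum_sq / hidden + eps);  // 1 / RMS(x)
    for (int64_t i = 0; i < hidden; ++i) y[r * hidden + i] = xr[i] * rrms * w[i];
  }
}
```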
@@ -30,6 +33,11 @@ at::Tensor rmsnorm_cpu(at::Tensor& input, at::Tensor& weight, double eps);

void fused_add_rmsnorm_cpu(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps);

// topk
std::tuple<at::Tensor, at::Tensor>
topk_sigmoid_cpu(at::Tensor& hidden_states, at::Tensor& gating_output, int64_t topk, bool renormalize);
std::tuple<at::Tensor, at::Tensor>
topk_softmax_cpu(at::Tensor& hidden_states, at::Tensor& gating_output, int64_t topk, bool renormalize);

std::tuple<at::Tensor, at::Tensor> grouped_topk_cpu(
    at::Tensor& hidden_states,
    at::Tensor& gating_output,
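`topk_softmax_cpu` follows the standard MoE routing pattern: softmax over each token's gating logits, take the top-k expert probabilities and indices, and optionally renormalize the kept weights to sum to 1. A minimal per-token sketch of that selection, assuming this semantics (the diff only shows the signature):

```cpp
#include <algorithm>
#include <cmath>
#include <numeric>
#include <vector>

// Per-token routing sketch: softmax over expert logits, then top-k.
// Assumed semantics; the CPU kernel fuses and vectorizes this across tokens.
void topk_softmax_ref(const std::vector<float>& logits, int k, bool renormalize,
                      std::vector<float>& weights, std::vector<int>& ids) {
  int e = static_cast<int>(logits.size());
  std::vector<float> probs(e);
  float mx = *std::max_element(logits.begin(), logits.end());
  float denom = 0.f;
  for (int i = 0; i < e; ++i) denom += (probs[i] = std::exp(logits[i] - mx));
  for (float& p : probs) p /= denom;

  // Indices of the k largest probabilities.
  ids.resize(e);
  std::iota(ids.begin(), ids.end(), 0);
  std::partial_sort(ids.begin(), ids.begin() + k, ids.end(),
                    [&](int a, int b) { return probs[a] > probs[b]; });
  ids.resize(k);

  weights.resize(k);
  float sum = 0.f;
  for (int i = 0; i < k; ++i) sum += (weights[i] = probs[ids[i]]);
  if (renormalize)
    for (float& w : weights) w /= sum;
}
```

`topk_sigmoid_cpu` presumably swaps the softmax for an element-wise sigmoid over the logits, and `grouped_topk_cpu` first restricts selection to the best expert groups, per the `num_expert_group` and `topk_group` parameters in the schema registered below.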
@@ -185,8 +193,13 @@ void shm_allreduce(

at::Tensor shm_allgather(at::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> process_group, int64_t dim);

// rope
std::tuple<at::Tensor, at::Tensor>
rotary_position_embedding_cpu(at::Tensor& t_pos, at::Tensor& q_pe, at::Tensor& k_pe, at::Tensor& t_emb_pos);
std::tuple<at::Tensor, at::Tensor> rotary_embedding_cpu(
    at::Tensor& positions,
    at::Tensor& query,
    at::Tensor& key,
    int64_t head_size,
    at::Tensor& cos_sin_cache,
    bool is_neox);

TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
  // activation
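`rotary_embedding_cpu` mirrors the common rotary-embedding op: for each position, a cached (cos, sin) pair per rotary dimension rotates query/key feature pairs, and `is_neox` selects how pairs are formed. A sketch of the NeoX-style rotation for one head, assuming the conventional cos_sin_cache row layout of cos values in the first half and sin values in the second (an assumption; the diff only shows the signature):

```cpp
#include <cstdint>

// NeoX-style in-place rotation of one head vector (assumed semantics).
// `cache` points at the cos_sin_cache row for this token's position:
// cache[0..half) holds cos, cache[half..rot_dim) holds sin.
void rotate_neox_ref(float* head, const float* cache, int64_t rot_dim) {
  int64_t half = rot_dim / 2;
  const float* cos_ptr = cache;
  const float* sin_ptr = cache + half;
  for (int64_t i = 0; i < half; ++i) {
    float x = head[i];           // pairs with the element half positions away
    float y = head[i + half];
    head[i]        = x * cos_ptr[i] - y * sin_ptr[i];
    head[i + half] = y * cos_ptr[i] + x * sin_ptr[i];
  }
}
```

When `is_neox` is false, the usual alternative is the interleaved style, which rotates adjacent element pairs (2i, 2i+1) instead of (i, i + half).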
@@ -196,10 +209,16 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {

  // norm
  m.def("rmsnorm_cpu(Tensor input, Tensor weight, float eps) -> Tensor");
  m.impl("rmsnorm_cpu", torch::kCPU, &rmsnorm_cpu);
  m.def("l2norm_cpu(Tensor input, float eps) -> Tensor");
  m.impl("l2norm_cpu", torch::kCPU, &l2norm_cpu);
  m.def("fused_add_rmsnorm_cpu(Tensor input, Tensor residual, Tensor weight, float eps) -> ()");
  m.impl("fused_add_rmsnorm_cpu", torch::kCPU, &fused_add_rmsnorm_cpu);

  // topk
  m.def("topk_sigmoid_cpu(Tensor hidden_states, Tensor gating_output, int topk, bool renormalize) -> (Tensor, Tensor)");
  m.impl("topk_sigmoid_cpu", torch::kCPU, &topk_sigmoid_cpu);
  m.def("topk_softmax_cpu(Tensor hidden_states, Tensor gating_output, int topk, bool renormalize) -> (Tensor, Tensor)");
  m.impl("topk_softmax_cpu", torch::kCPU, &topk_softmax_cpu);
  m.def(
      "grouped_topk_cpu(Tensor hidden_states, Tensor gating_output, int topk, bool renormalize, int num_expert_group, "
      "int topk_group) -> (Tensor, Tensor)");
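Each `m.def` registers a schema under the `sgl_kernel` namespace and the matching `m.impl` binds the CPU kernel, so the ops are reachable through the PyTorch dispatcher (e.g. `torch.ops.sgl_kernel.rmsnorm_cpu` from Python). A sketch of calling one of them from C++ through a typed dispatcher handle, assuming the extension library has already been loaded so this fragment has run:

```cpp
#include <ATen/ATen.h>
#include <ATen/core/dispatch/Dispatcher.h>

// Look up the registered schema once, then call the bound CPU kernel.
at::Tensor call_rmsnorm(const at::Tensor& input, const at::Tensor& weight, double eps) {
  static auto op = c10::Dispatcher::singleton()
                       .findSchemaOrThrow("sgl_kernel::rmsnorm_cpu", "")
                       .typed<at::Tensor(const at::Tensor&, const at::Tensor&, double)>();
  return op.call(input, weight, eps);
}
```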
@@ -294,8 +313,10 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {

  m.impl("shm_allgather", torch::kCPU, &shm_allgather);

  // rope
  m.def("rotary_position_embedding_cpu(Tensor t_pos, Tensor q_pe, Tensor k_pe, Tensor t_emb_pos) -> (Tensor, Tensor)");
  m.impl("rotary_position_embedding_cpu", torch::kCPU, &rotary_position_embedding_cpu);
  m.def(
      "rotary_embedding_cpu(Tensor positions, Tensor query, Tensor key, int head_size, Tensor cos_sin_cache, "
      "bool is_neox) -> (Tensor, Tensor)");
  m.impl("rotary_embedding_cpu", torch::kCPU, &rotary_embedding_cpu);
}

REGISTER_EXTENSION(common_ops)
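A usage sketch for the registered `rotary_embedding_cpu`, using the shapes these ops conventionally take: one position per token, query/key flattened to [num_tokens, num_heads * head_size], and a [max_pos, rot_dim] cos/sin cache. The shapes, dtypes, and rot_dim == head_size are assumptions based on the common convention, not stated in this diff:

```cpp
#include <ATen/ATen.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <cstdint>
#include <tuple>

int main() {
  int64_t num_tokens = 4, num_heads = 8, head_size = 64, max_pos = 2048;
  auto positions = at::arange(num_tokens, at::kLong);
  auto query = at::randn({num_tokens, num_heads * head_size});
  auto key = at::randn({num_tokens, num_heads * head_size});
  // Assumed layout: cos half followed by sin half per position row.
  auto cos_sin_cache = at::randn({max_pos, head_size});

  auto op = c10::Dispatcher::singleton()
                .findSchemaOrThrow("sgl_kernel::rotary_embedding_cpu", "")
                .typed<std::tuple<at::Tensor, at::Tensor>(
                    const at::Tensor&, const at::Tensor&, const at::Tensor&,
                    int64_t, const at::Tensor&, bool)>();
  auto [q_out, k_out] = op.call(positions, query, key, head_size,
                                cos_sin_cache, /*is_neox=*/true);
}
```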