Add CPU optimized kernels for topk and rope fusions (#6456)

This commit is contained in:
jianan-gu
2025-06-03 08:37:34 +08:00
committed by GitHub
parent ff91474825
commit ff00895c46
7 changed files with 829 additions and 98 deletions

View File

@@ -157,6 +157,101 @@ inline void sigmoid(float* __restrict__ out, const scalar_t* __restrict__ input)
}
}
template <typename scalar_t, int NUM_EXPERTS>
void topk_sigmoid_kernel_impl(
float* __restrict__ topk_weights,
int32_t* __restrict__ topk_ids,
const scalar_t* __restrict__ gating_output,
int64_t num_tokens,
int64_t topk,
bool renormalize) {
using Vec = at::vec::Vectorized<float>;
const int64_t num_experts_per_group = NUM_EXPERTS;
at::parallel_for(0, num_tokens, 0, [&](int64_t begin, int64_t end) {
alignas(64) float scores[NUM_EXPERTS];
using elem_t = std::pair<float, int32_t>;
std::vector<elem_t> queue(num_experts_per_group);
for (int64_t i = begin; i < end; ++i) {
at::vec::convert<scalar_t, float>(gating_output + i * NUM_EXPERTS, scores, NUM_EXPERTS);
float gmax = at::vec::reduce_all<float>(
[](Vec& x, Vec& y) { return at::vec::maximum(x, y); }, scores, num_experts_per_group);
// find position of first max,
// note that we may have multiple max values.
int first_max_idx = -1;
for (int64_t e = 0; e < num_experts_per_group; ++e) {
if (scores[e] == gmax) {
first_max_idx = e;
break;
}
}
// scalar sigmoid
topk_weights[i] = 1.0 / (1.0 + exp(0.0 - gmax));
topk_ids[i] = first_max_idx;
if (renormalize) {
float sum = 0.f;
for (int64_t j = 0; j < topk; ++j) {
sum += topk_weights[i * topk + j];
}
float scale = 1.f / sum;
for (int64_t j = 0; j < topk; ++j) {
topk_weights[i * topk + j] *= scale;
}
}
}
});
}
// Fused top-k + softmax routing kernel.
//
// For each token: softmax over all NUM_EXPERTS gating logits, then
// select the topk highest probabilities and their expert ids.
// Equal scores keep ascending expert-id order only as far as
// std::partial_sort's unspecified tie handling allows.
//
//   topk_weights  - [num_tokens, topk] fp32 output routing weights
//   topk_ids      - [num_tokens, topk] int32 output expert ids
//   gating_output - [num_tokens, NUM_EXPERTS] raw gating logits
//   renormalize   - rescale each token's selected weights to sum to 1
template <typename scalar_t, int NUM_EXPERTS>
void topk_softmax_kernel_impl(
    float* __restrict__ topk_weights,
    int32_t* __restrict__ topk_ids,
    const scalar_t* __restrict__ gating_output,
    int64_t num_tokens,
    int64_t topk,
    bool renormalize) {
  const int64_t num_experts_per_group = NUM_EXPERTS;
  at::parallel_for(0, num_tokens, 0, [&](int64_t begin, int64_t end) {
    // per-thread fp32 staging buffer for one token's probabilities
    alignas(64) float scores[NUM_EXPERTS];
    using elem_t = std::pair<float, int32_t>;
    // (score, expert_id) scratch, allocated once per parallel chunk
    std::vector<elem_t> queue(num_experts_per_group);
    for (int64_t i = begin; i < end; ++i) {
      softmax<scalar_t, NUM_EXPERTS>(scores, gating_output + i * NUM_EXPERTS);
      for (int64_t e = 0; e < num_experts_per_group; ++e) {
        queue[e] = {scores[e], e};
      }
      // only the leading topk entries are consumed below, so sort just
      // those (previously the middle iterator was `end`, which degraded
      // partial_sort into a full O(E log E) sort of all experts)
      std::partial_sort(
          queue.begin(),
          queue.begin() + std::min(topk, num_experts_per_group),
          queue.end(),
          [](const elem_t& x, const elem_t& y) -> bool { return x.first > y.first; });
      for (int64_t j = 0; j < topk; ++j) {
        topk_weights[i * topk + j] = queue[j].first;
        topk_ids[i * topk + j] = queue[j].second;
      }
      if (renormalize) {
        float sum = 0.f;
        for (int64_t j = 0; j < topk; ++j) {
          sum += topk_weights[i * topk + j];
        }
        float scale = 1.f / sum;
        for (int64_t j = 0; j < topk; ++j) {
          topk_weights[i * topk + j] *= scale;
        }
      }
    }
  });
}
template <typename scalar_t, int SIZE>
inline void
apply_bias(float* __restrict__ scores2, const float* __restrict__ scores, const scalar_t* __restrict__ bias) {
@@ -293,6 +388,24 @@ void biased_grouped_topk_kernel_impl(
topk_group, \
renormalize);
// Dispatch helper: instantiate topk_sigmoid_kernel_impl for the
// compile-time expert count NE. Expands inside an AT_DISPATCH lambda,
// which supplies `scalar_t` and the surrounding tensor locals.
#define LAUNCH_TOPK_SIGMOID_KERNEL(NE) \
topk_sigmoid_kernel_impl<scalar_t, NE>( \
topk_weights.data_ptr<float>(), \
topk_ids.data_ptr<int32_t>(), \
gating_output.data_ptr<scalar_t>(), \
num_tokens, \
topk, \
renormalize);
// Dispatch helper: instantiate topk_softmax_kernel_impl for the
// compile-time expert count NE. Expands inside an AT_DISPATCH lambda,
// which supplies `scalar_t` and the surrounding tensor locals.
#define LAUNCH_TOPK_SOFTMAX_KERNEL(NE) \
topk_softmax_kernel_impl<scalar_t, NE>( \
topk_weights.data_ptr<float>(), \
topk_ids.data_ptr<int32_t>(), \
gating_output.data_ptr<scalar_t>(), \
num_tokens, \
topk, \
renormalize);
#define LAUNCH_BIASED_GROUPED_TOPK_KERNEL(NE, NTOPK) \
biased_grouped_topk_kernel_impl<scalar_t, NE, NTOPK>( \
topk_weights.data_ptr<float>(), \
@@ -306,6 +419,114 @@ void biased_grouped_topk_kernel_impl(
} // anonymous namespace
// Top-1 sigmoid routing entry point (CPU).
//
// hidden_states  - [num_tokens, hidden] activations; only used here for
//                  its token count and tensor options (dtype/device)
// gating_output  - [num_tokens, num_experts] gating logits; must be
//                  contiguous (CHECK_INPUT) and match hidden_states dtype
// topk           - must be 1 (this kernel only computes top-1)
// renormalize    - rescale weights per token to sum to 1
//
// Returns (topk_weights fp32 [num_tokens, topk],
//          topk_ids int32 [num_tokens, topk]).
// Dispatches to a kernel specialized on the compile-time expert count;
// only the listed power-of-two-ish counts are supported.
std::tuple<at::Tensor, at::Tensor>
topk_sigmoid_cpu(at::Tensor& hidden_states, at::Tensor& gating_output, int64_t topk, bool renormalize) {
RECORD_FUNCTION("sgl-kernel::topk_sigmoid_cpu", std::vector<c10::IValue>({hidden_states, gating_output}));
CHECK_INPUT(gating_output);
const auto st = hidden_states.scalar_type();
CHECK_EQ(gating_output.scalar_type(), st);
int64_t num_tokens = hidden_states.size(0);
int64_t num_experts = gating_output.size(1);
TORCH_CHECK(gating_output.size(0) == num_tokens, "Number of tokens mismatch");
TORCH_CHECK(topk == 1, "topk_sigmoid only supports topk=1 case");
// outputs: fp32 weights and int32 expert ids, one row per token
at::Tensor topk_weights = at::empty({num_tokens, topk}, hidden_states.options().dtype(at::kFloat));
at::Tensor topk_ids = at::empty({num_tokens, topk}, hidden_states.options().dtype(at::kInt));
// dispatch over reduced float dtypes (fp16/bf16), then over expert count
// (NE must be a compile-time constant for the kernel template)
AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "topk_sigmoid_kernel", [&] {
switch (num_experts) {
case 1:
LAUNCH_TOPK_SIGMOID_KERNEL(1);
break;
case 2:
LAUNCH_TOPK_SIGMOID_KERNEL(2);
break;
case 4:
LAUNCH_TOPK_SIGMOID_KERNEL(4);
break;
case 8:
LAUNCH_TOPK_SIGMOID_KERNEL(8);
break;
case 16:
LAUNCH_TOPK_SIGMOID_KERNEL(16);
break;
case 32:
LAUNCH_TOPK_SIGMOID_KERNEL(32);
break;
case 64:
LAUNCH_TOPK_SIGMOID_KERNEL(64);
break;
case 128:
LAUNCH_TOPK_SIGMOID_KERNEL(128);
break;
case 160:
LAUNCH_TOPK_SIGMOID_KERNEL(160);
break;
case 256:
LAUNCH_TOPK_SIGMOID_KERNEL(256);
break;
default:
TORCH_CHECK(false, "Unexpected num_experts: ", num_experts);
}
});
return std::make_tuple(topk_weights, topk_ids);
}
// Top-k softmax routing entry point (CPU).
//
// hidden_states  - [num_tokens, hidden] activations; only used here for
//                  its token count and tensor options (dtype/device)
// gating_output  - [num_tokens, num_experts] gating logits; must be
//                  contiguous (CHECK_INPUT) and match hidden_states dtype
// topk           - number of experts selected per token
// renormalize    - rescale each token's selected weights to sum to 1
//
// Returns (topk_weights fp32 [num_tokens, topk],
//          topk_ids int32 [num_tokens, topk]).
// Dispatches to a kernel specialized on the compile-time expert count;
// only the listed expert counts are supported.
std::tuple<at::Tensor, at::Tensor>
topk_softmax_cpu(at::Tensor& hidden_states, at::Tensor& gating_output, int64_t topk, bool renormalize) {
RECORD_FUNCTION("sgl-kernel::topk_softmax_cpu", std::vector<c10::IValue>({hidden_states, gating_output}));
CHECK_INPUT(gating_output);
const auto st = hidden_states.scalar_type();
CHECK_EQ(gating_output.scalar_type(), st);
int64_t num_tokens = hidden_states.size(0);
int64_t num_experts = gating_output.size(1);
TORCH_CHECK(gating_output.size(0) == num_tokens, "Number of tokens mismatch");
// outputs: fp32 weights and int32 expert ids, topk entries per token
at::Tensor topk_weights = at::empty({num_tokens, topk}, hidden_states.options().dtype(at::kFloat));
at::Tensor topk_ids = at::empty({num_tokens, topk}, hidden_states.options().dtype(at::kInt));
// dispatch over reduced float dtypes (fp16/bf16), then over expert count
// (NE must be a compile-time constant for the kernel template)
AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "topk_softmax_cpu", [&] {
switch (num_experts) {
case 1:
LAUNCH_TOPK_SOFTMAX_KERNEL(1);
break;
case 2:
LAUNCH_TOPK_SOFTMAX_KERNEL(2);
break;
case 4:
LAUNCH_TOPK_SOFTMAX_KERNEL(4);
break;
case 8:
LAUNCH_TOPK_SOFTMAX_KERNEL(8);
break;
case 16:
LAUNCH_TOPK_SOFTMAX_KERNEL(16);
break;
case 32:
LAUNCH_TOPK_SOFTMAX_KERNEL(32);
break;
case 64:
LAUNCH_TOPK_SOFTMAX_KERNEL(64);
break;
case 128:
LAUNCH_TOPK_SOFTMAX_KERNEL(128);
break;
case 160:
LAUNCH_TOPK_SOFTMAX_KERNEL(160);
break;
case 256:
LAUNCH_TOPK_SOFTMAX_KERNEL(256);
break;
default:
TORCH_CHECK(false, "Unexpected num_experts: ", num_experts);
}
});
return std::make_tuple(topk_weights, topk_ids);
}
// grouped topk for DeepSeek V2
std::tuple<at::Tensor, at::Tensor> grouped_topk_cpu(
at::Tensor& hidden_states,