Add CPU optimized kernels for topk and rope fusions (#6456)
This commit is contained in:
@@ -157,6 +157,101 @@ inline void sigmoid(float* __restrict__ out, const scalar_t* __restrict__ input)
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t, int NUM_EXPERTS>
|
||||
void topk_sigmoid_kernel_impl(
|
||||
float* __restrict__ topk_weights,
|
||||
int32_t* __restrict__ topk_ids,
|
||||
const scalar_t* __restrict__ gating_output,
|
||||
int64_t num_tokens,
|
||||
int64_t topk,
|
||||
bool renormalize) {
|
||||
using Vec = at::vec::Vectorized<float>;
|
||||
const int64_t num_experts_per_group = NUM_EXPERTS;
|
||||
at::parallel_for(0, num_tokens, 0, [&](int64_t begin, int64_t end) {
|
||||
alignas(64) float scores[NUM_EXPERTS];
|
||||
using elem_t = std::pair<float, int32_t>;
|
||||
std::vector<elem_t> queue(num_experts_per_group);
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
at::vec::convert<scalar_t, float>(gating_output + i * NUM_EXPERTS, scores, NUM_EXPERTS);
|
||||
|
||||
float gmax = at::vec::reduce_all<float>(
|
||||
[](Vec& x, Vec& y) { return at::vec::maximum(x, y); }, scores, num_experts_per_group);
|
||||
|
||||
// find position of first max,
|
||||
// note that we may have multiple max values.
|
||||
int first_max_idx = -1;
|
||||
for (int64_t e = 0; e < num_experts_per_group; ++e) {
|
||||
if (scores[e] == gmax) {
|
||||
first_max_idx = e;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// scalar sigmoid
|
||||
topk_weights[i] = 1.0 / (1.0 + exp(0.0 - gmax));
|
||||
topk_ids[i] = first_max_idx;
|
||||
|
||||
if (renormalize) {
|
||||
float sum = 0.f;
|
||||
for (int64_t j = 0; j < topk; ++j) {
|
||||
sum += topk_weights[i * topk + j];
|
||||
}
|
||||
float scale = 1.f / sum;
|
||||
for (int64_t j = 0; j < topk; ++j) {
|
||||
topk_weights[i * topk + j] *= scale;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename scalar_t, int NUM_EXPERTS>
|
||||
void topk_softmax_kernel_impl(
|
||||
float* __restrict__ topk_weights,
|
||||
int32_t* __restrict__ topk_ids,
|
||||
const scalar_t* __restrict__ gating_output,
|
||||
int64_t num_tokens,
|
||||
int64_t topk,
|
||||
bool renormalize) {
|
||||
const int64_t num_experts_per_group = NUM_EXPERTS;
|
||||
at::parallel_for(0, num_tokens, 0, [&](int64_t begin, int64_t end) {
|
||||
alignas(64) float scores[NUM_EXPERTS];
|
||||
using elem_t = std::pair<float, int32_t>;
|
||||
std::vector<elem_t> queue(num_experts_per_group);
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
softmax<scalar_t, NUM_EXPERTS>(scores, gating_output + i * NUM_EXPERTS);
|
||||
|
||||
for (int64_t e = 0; e < num_experts_per_group; ++e) {
|
||||
queue[e] = {scores[e], e};
|
||||
}
|
||||
|
||||
std::partial_sort(
|
||||
queue.begin(),
|
||||
queue.begin() + num_experts_per_group,
|
||||
queue.end(),
|
||||
[](const elem_t& x, const elem_t& y) -> bool { return x.first > y.first; });
|
||||
|
||||
for (int64_t j = 0; j < topk; ++j) {
|
||||
topk_weights[i * topk + j] = queue[j].first;
|
||||
topk_ids[i * topk + j] = queue[j].second;
|
||||
}
|
||||
|
||||
if (renormalize) {
|
||||
float sum = 0.f;
|
||||
for (int64_t j = 0; j < topk; ++j) {
|
||||
sum += topk_weights[i * topk + j];
|
||||
}
|
||||
float scale = 1.f / sum;
|
||||
for (int64_t j = 0; j < topk; ++j) {
|
||||
topk_weights[i * topk + j] *= scale;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename scalar_t, int SIZE>
|
||||
inline void
|
||||
apply_bias(float* __restrict__ scores2, const float* __restrict__ scores, const scalar_t* __restrict__ bias) {
|
||||
@@ -293,6 +388,24 @@ void biased_grouped_topk_kernel_impl(
|
||||
topk_group, \
|
||||
renormalize);
|
||||
|
||||
// Dispatch helper: expands to a call of topk_sigmoid_kernel_impl
// instantiated for the compile-time expert count NE. Relies on names from
// the expansion site: `scalar_t` (bound by the AT_DISPATCH lambda) and the
// locals topk_weights / topk_ids / gating_output / num_tokens / topk /
// renormalize of the *_cpu wrapper.
#define LAUNCH_TOPK_SIGMOID_KERNEL(NE)    \
  topk_sigmoid_kernel_impl<scalar_t, NE>( \
      topk_weights.data_ptr<float>(),     \
      topk_ids.data_ptr<int32_t>(),       \
      gating_output.data_ptr<scalar_t>(), \
      num_tokens,                         \
      topk,                               \
      renormalize);

// Same dispatch pattern for the softmax variant; see
// LAUNCH_TOPK_SIGMOID_KERNEL above for the names expected in scope.
#define LAUNCH_TOPK_SOFTMAX_KERNEL(NE)    \
  topk_softmax_kernel_impl<scalar_t, NE>( \
      topk_weights.data_ptr<float>(),     \
      topk_ids.data_ptr<int32_t>(),       \
      gating_output.data_ptr<scalar_t>(), \
      num_tokens,                         \
      topk,                               \
      renormalize);
|
||||
|
||||
#define LAUNCH_BIASED_GROUPED_TOPK_KERNEL(NE, NTOPK) \
|
||||
biased_grouped_topk_kernel_impl<scalar_t, NE, NTOPK>( \
|
||||
topk_weights.data_ptr<float>(), \
|
||||
@@ -306,6 +419,114 @@ void biased_grouped_topk_kernel_impl(
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor>
|
||||
topk_sigmoid_cpu(at::Tensor& hidden_states, at::Tensor& gating_output, int64_t topk, bool renormalize) {
|
||||
RECORD_FUNCTION("sgl-kernel::topk_sigmoid_cpu", std::vector<c10::IValue>({hidden_states, gating_output}));
|
||||
CHECK_INPUT(gating_output);
|
||||
|
||||
const auto st = hidden_states.scalar_type();
|
||||
CHECK_EQ(gating_output.scalar_type(), st);
|
||||
|
||||
int64_t num_tokens = hidden_states.size(0);
|
||||
int64_t num_experts = gating_output.size(1);
|
||||
TORCH_CHECK(gating_output.size(0) == num_tokens, "Number of tokens mismatch");
|
||||
TORCH_CHECK(topk == 1, "topk_sigmoid only supports topk=1 case");
|
||||
at::Tensor topk_weights = at::empty({num_tokens, topk}, hidden_states.options().dtype(at::kFloat));
|
||||
at::Tensor topk_ids = at::empty({num_tokens, topk}, hidden_states.options().dtype(at::kInt));
|
||||
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "topk_sigmoid_kernel", [&] {
|
||||
switch (num_experts) {
|
||||
case 1:
|
||||
LAUNCH_TOPK_SIGMOID_KERNEL(1);
|
||||
break;
|
||||
case 2:
|
||||
LAUNCH_TOPK_SIGMOID_KERNEL(2);
|
||||
break;
|
||||
case 4:
|
||||
LAUNCH_TOPK_SIGMOID_KERNEL(4);
|
||||
break;
|
||||
case 8:
|
||||
LAUNCH_TOPK_SIGMOID_KERNEL(8);
|
||||
break;
|
||||
case 16:
|
||||
LAUNCH_TOPK_SIGMOID_KERNEL(16);
|
||||
break;
|
||||
case 32:
|
||||
LAUNCH_TOPK_SIGMOID_KERNEL(32);
|
||||
break;
|
||||
case 64:
|
||||
LAUNCH_TOPK_SIGMOID_KERNEL(64);
|
||||
break;
|
||||
case 128:
|
||||
LAUNCH_TOPK_SIGMOID_KERNEL(128);
|
||||
break;
|
||||
case 160:
|
||||
LAUNCH_TOPK_SIGMOID_KERNEL(160);
|
||||
break;
|
||||
case 256:
|
||||
LAUNCH_TOPK_SIGMOID_KERNEL(256);
|
||||
break;
|
||||
default:
|
||||
TORCH_CHECK(false, "Unexpected num_experts: ", num_experts);
|
||||
}
|
||||
});
|
||||
return std::make_tuple(topk_weights, topk_ids);
|
||||
}
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor>
|
||||
topk_softmax_cpu(at::Tensor& hidden_states, at::Tensor& gating_output, int64_t topk, bool renormalize) {
|
||||
RECORD_FUNCTION("sgl-kernel::topk_softmax_cpu", std::vector<c10::IValue>({hidden_states, gating_output}));
|
||||
CHECK_INPUT(gating_output);
|
||||
|
||||
const auto st = hidden_states.scalar_type();
|
||||
CHECK_EQ(gating_output.scalar_type(), st);
|
||||
|
||||
int64_t num_tokens = hidden_states.size(0);
|
||||
int64_t num_experts = gating_output.size(1);
|
||||
TORCH_CHECK(gating_output.size(0) == num_tokens, "Number of tokens mismatch");
|
||||
|
||||
at::Tensor topk_weights = at::empty({num_tokens, topk}, hidden_states.options().dtype(at::kFloat));
|
||||
at::Tensor topk_ids = at::empty({num_tokens, topk}, hidden_states.options().dtype(at::kInt));
|
||||
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "topk_softmax_cpu", [&] {
|
||||
switch (num_experts) {
|
||||
case 1:
|
||||
LAUNCH_TOPK_SOFTMAX_KERNEL(1);
|
||||
break;
|
||||
case 2:
|
||||
LAUNCH_TOPK_SOFTMAX_KERNEL(2);
|
||||
break;
|
||||
case 4:
|
||||
LAUNCH_TOPK_SOFTMAX_KERNEL(4);
|
||||
break;
|
||||
case 8:
|
||||
LAUNCH_TOPK_SOFTMAX_KERNEL(8);
|
||||
break;
|
||||
case 16:
|
||||
LAUNCH_TOPK_SOFTMAX_KERNEL(16);
|
||||
break;
|
||||
case 32:
|
||||
LAUNCH_TOPK_SOFTMAX_KERNEL(32);
|
||||
break;
|
||||
case 64:
|
||||
LAUNCH_TOPK_SOFTMAX_KERNEL(64);
|
||||
break;
|
||||
case 128:
|
||||
LAUNCH_TOPK_SOFTMAX_KERNEL(128);
|
||||
break;
|
||||
case 160:
|
||||
LAUNCH_TOPK_SOFTMAX_KERNEL(160);
|
||||
break;
|
||||
case 256:
|
||||
LAUNCH_TOPK_SOFTMAX_KERNEL(256);
|
||||
break;
|
||||
default:
|
||||
TORCH_CHECK(false, "Unexpected num_experts: ", num_experts);
|
||||
}
|
||||
});
|
||||
return std::make_tuple(topk_weights, topk_ids);
|
||||
}
|
||||
|
||||
// grouped topk for DeepSeek V2
|
||||
std::tuple<at::Tensor, at::Tensor> grouped_topk_cpu(
|
||||
at::Tensor& hidden_states,
|
||||
|
||||
Reference in New Issue
Block a user