[CPU][sgl-kernel] biased_grouped_topk: fix correction_bias dtype to float32 (#8212)

Co-authored-by: jianan-gu <jianan.gu@intel.com>
Co-authored-by: YanbingJiang <yanbing.jiang@intel.com>
This commit is contained in:
Chunyuan WU
2025-08-05 09:28:31 +08:00
committed by GitHub
parent d4bf5a8524
commit 08f8f49016
4 changed files with 94 additions and 28 deletions

View File

@@ -252,29 +252,33 @@ void topk_softmax_kernel_impl(
});
}
template <typename scalar_t, int SIZE>
template <typename scalar_t, typename param_t, int SIZE>
inline void
apply_bias(float* __restrict__ scores2, const float* __restrict__ scores, const scalar_t* __restrict__ bias) {
using bVec = at::vec::Vectorized<scalar_t>;
apply_bias(float* __restrict__ scores2, const float* __restrict__ scores, const param_t* __restrict__ bias) {
using fVec = at::vec::Vectorized<float>;
for (int d = 0; d < SIZE; d += bVec::size()) {
bVec bias_vec = bVec::loadu(bias + d);
fVec bias0, bias1;
std::tie(bias0, bias1) = at::vec::convert_to_float(bias_vec);
fVec x0 = fVec::loadu(scores + d) + bias0;
fVec x1 = fVec::loadu(scores + d + fVec::size()) + bias1;
using bVec = at::vec::Vectorized<scalar_t>;
auto vec_size = bVec::size();
int d = 0;
for (; d <= SIZE - vec_size; d += vec_size) {
fVec bias0, bias1, x0, x1;
std::tie(bias0, bias1) = load_float_vec2(bias + d);
std::tie(x0, x1) = load_float_vec2(scores + d);
x0 = x0 + bias0;
x1 = x1 + bias1;
x0.store(scores2 + d);
x1.store(scores2 + d + fVec::size());
}
for (; d < SIZE; d++) {
scores2[d] = scores[d] + (float)bias[d];
}
}
template <typename scalar_t, int NUM_EXPERTS, int TOPK>
template <typename scalar_t, typename param_t, int NUM_EXPERTS, int TOPK>
void biased_grouped_topk_kernel_impl(
float* __restrict__ topk_weights,
int32_t* __restrict__ topk_ids,
const scalar_t* __restrict__ gating_output,
const scalar_t* __restrict__ bias,
const param_t* __restrict__ bias,
int64_t num_tokens,
int64_t num_groups,
int64_t topk_group,
@@ -295,7 +299,8 @@ void biased_grouped_topk_kernel_impl(
for (int64_t i = begin; i < end; ++i) {
// do sigmoid to get scores
sigmoid<scalar_t, NUM_EXPERTS>(scores, gating_output + i * NUM_EXPERTS);
apply_bias<scalar_t, NUM_EXPERTS>(scores2, scores, bias);
apply_bias<scalar_t, param_t, NUM_EXPERTS>(scores2, scores, bias);
for (int64_t g = 0; g < num_groups; ++g) {
// find the max
@@ -406,15 +411,15 @@ void biased_grouped_topk_kernel_impl(
topk, \
renormalize);
#define LAUNCH_BIASED_GROUPED_TOPK_KERNEL(NE, NTOPK) \
biased_grouped_topk_kernel_impl<scalar_t, NE, NTOPK>( \
topk_weights.data_ptr<float>(), \
topk_ids.data_ptr<int32_t>(), \
gating_output.data_ptr<scalar_t>(), \
correction_bias.data_ptr<scalar_t>(), \
num_tokens, \
num_expert_group, \
topk_group, \
#define LAUNCH_BIASED_GROUPED_TOPK_KERNEL(NE, NTOPK) \
biased_grouped_topk_kernel_impl<scalar_t, param_t, NE, NTOPK>( \
topk_weights.data_ptr<float>(), \
topk_ids.data_ptr<int32_t>(), \
gating_output.data_ptr<scalar_t>(), \
correction_bias.data_ptr<param_t>(), \
num_tokens, \
num_expert_group, \
topk_group, \
renormalize);
} // anonymous namespace
@@ -635,7 +640,6 @@ std::tuple<at::Tensor, at::Tensor> biased_grouped_topk_cpu(
const auto st = hidden_states.scalar_type();
CHECK_EQ(gating_output.scalar_type(), st);
CHECK_EQ(correction_bias.scalar_type(), st);
int64_t num_tokens = hidden_states.size(0);
int64_t num_experts = gating_output.size(1);
@@ -644,8 +648,7 @@ std::tuple<at::Tensor, at::Tensor> biased_grouped_topk_cpu(
at::Tensor topk_weights = at::empty({num_tokens, topk}, hidden_states.options().dtype(at::kFloat));
at::Tensor topk_ids = at::empty({num_tokens, topk}, hidden_states.options().dtype(at::kInt));
AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "biased_grouped_topk_kernel", [&] {
// NOW only support DSv3 configs
CPU_DISPATCH_REDUCED_FLOATING_TYPES_EXT(st, correction_bias.scalar_type(), "biased_grouped_topk_kernel", [&] {
TORCH_CHECK(topk == 8, "Unexpected topk: ", topk);
switch (num_experts) {
case 256: