[CPU][sgl-kernel] biased_grouped_topk: fix correction_bias dtype to float32 (#8212)
Co-authored-by: jianan-gu <jianan.gu@intel.com> Co-authored-by: YanbingJiang <yanbing.jiang@intel.com>
This commit is contained in:
@@ -252,29 +252,33 @@ void topk_softmax_kernel_impl(
|
||||
});
|
||||
}
|
||||
|
||||
template <typename scalar_t, int SIZE>
|
||||
template <typename scalar_t, typename param_t, int SIZE>
|
||||
inline void
|
||||
apply_bias(float* __restrict__ scores2, const float* __restrict__ scores, const scalar_t* __restrict__ bias) {
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
apply_bias(float* __restrict__ scores2, const float* __restrict__ scores, const param_t* __restrict__ bias) {
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
for (int d = 0; d < SIZE; d += bVec::size()) {
|
||||
bVec bias_vec = bVec::loadu(bias + d);
|
||||
fVec bias0, bias1;
|
||||
std::tie(bias0, bias1) = at::vec::convert_to_float(bias_vec);
|
||||
|
||||
fVec x0 = fVec::loadu(scores + d) + bias0;
|
||||
fVec x1 = fVec::loadu(scores + d + fVec::size()) + bias1;
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
auto vec_size = bVec::size();
|
||||
int d = 0;
|
||||
for (; d <= SIZE - vec_size; d += vec_size) {
|
||||
fVec bias0, bias1, x0, x1;
|
||||
std::tie(bias0, bias1) = load_float_vec2(bias + d);
|
||||
std::tie(x0, x1) = load_float_vec2(scores + d);
|
||||
x0 = x0 + bias0;
|
||||
x1 = x1 + bias1;
|
||||
x0.store(scores2 + d);
|
||||
x1.store(scores2 + d + fVec::size());
|
||||
}
|
||||
for (; d < SIZE; d++) {
|
||||
scores2[d] = scores[d] + (float)bias[d];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t, int NUM_EXPERTS, int TOPK>
|
||||
template <typename scalar_t, typename param_t, int NUM_EXPERTS, int TOPK>
|
||||
void biased_grouped_topk_kernel_impl(
|
||||
float* __restrict__ topk_weights,
|
||||
int32_t* __restrict__ topk_ids,
|
||||
const scalar_t* __restrict__ gating_output,
|
||||
const scalar_t* __restrict__ bias,
|
||||
const param_t* __restrict__ bias,
|
||||
int64_t num_tokens,
|
||||
int64_t num_groups,
|
||||
int64_t topk_group,
|
||||
@@ -295,7 +299,8 @@ void biased_grouped_topk_kernel_impl(
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
// do sigmoid to get scores
|
||||
sigmoid<scalar_t, NUM_EXPERTS>(scores, gating_output + i * NUM_EXPERTS);
|
||||
apply_bias<scalar_t, NUM_EXPERTS>(scores2, scores, bias);
|
||||
|
||||
apply_bias<scalar_t, param_t, NUM_EXPERTS>(scores2, scores, bias);
|
||||
|
||||
for (int64_t g = 0; g < num_groups; ++g) {
|
||||
// find the max
|
||||
@@ -406,15 +411,15 @@ void biased_grouped_topk_kernel_impl(
|
||||
topk, \
|
||||
renormalize);
|
||||
|
||||
#define LAUNCH_BIASED_GROUPED_TOPK_KERNEL(NE, NTOPK) \
|
||||
biased_grouped_topk_kernel_impl<scalar_t, NE, NTOPK>( \
|
||||
topk_weights.data_ptr<float>(), \
|
||||
topk_ids.data_ptr<int32_t>(), \
|
||||
gating_output.data_ptr<scalar_t>(), \
|
||||
correction_bias.data_ptr<scalar_t>(), \
|
||||
num_tokens, \
|
||||
num_expert_group, \
|
||||
topk_group, \
|
||||
#define LAUNCH_BIASED_GROUPED_TOPK_KERNEL(NE, NTOPK) \
|
||||
biased_grouped_topk_kernel_impl<scalar_t, param_t, NE, NTOPK>( \
|
||||
topk_weights.data_ptr<float>(), \
|
||||
topk_ids.data_ptr<int32_t>(), \
|
||||
gating_output.data_ptr<scalar_t>(), \
|
||||
correction_bias.data_ptr<param_t>(), \
|
||||
num_tokens, \
|
||||
num_expert_group, \
|
||||
topk_group, \
|
||||
renormalize);
|
||||
|
||||
} // anonymous namespace
|
||||
@@ -635,7 +640,6 @@ std::tuple<at::Tensor, at::Tensor> biased_grouped_topk_cpu(
|
||||
|
||||
const auto st = hidden_states.scalar_type();
|
||||
CHECK_EQ(gating_output.scalar_type(), st);
|
||||
CHECK_EQ(correction_bias.scalar_type(), st);
|
||||
|
||||
int64_t num_tokens = hidden_states.size(0);
|
||||
int64_t num_experts = gating_output.size(1);
|
||||
@@ -644,8 +648,7 @@ std::tuple<at::Tensor, at::Tensor> biased_grouped_topk_cpu(
|
||||
at::Tensor topk_weights = at::empty({num_tokens, topk}, hidden_states.options().dtype(at::kFloat));
|
||||
at::Tensor topk_ids = at::empty({num_tokens, topk}, hidden_states.options().dtype(at::kInt));
|
||||
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "biased_grouped_topk_kernel", [&] {
|
||||
// NOW only support DSv3 configs
|
||||
CPU_DISPATCH_REDUCED_FLOATING_TYPES_EXT(st, correction_bias.scalar_type(), "biased_grouped_topk_kernel", [&] {
|
||||
TORCH_CHECK(topk == 8, "Unexpected topk: ", topk);
|
||||
switch (num_experts) {
|
||||
case 256:
|
||||
|
||||
Reference in New Issue
Block a user