[CPU][sgl-kernel] biased_grouped_topk: fix correction_bias dtype to float32 (#8212)
Co-authored-by: jianan-gu <jianan.gu@intel.com> Co-authored-by: YanbingJiang <yanbing.jiang@intel.com>
This commit is contained in:
@@ -47,6 +47,45 @@ namespace {
|
||||
} \
|
||||
}()
|
||||
|
||||
// dispatch with mixed dtypes (TYPE1, TYPE2):
|
||||
// TYPE1: the primary dtype (input, output, weight);
|
||||
// TYPE2: the secondary dtype (bias, etc.).
|
||||
#define CPU_DISPATCH_REDUCED_FLOATING_TYPES_EXT(TYPE1, TYPE2, ...) \
|
||||
[&] { \
|
||||
if (TYPE2 == at::kFloat) { \
|
||||
switch (TYPE1) { \
|
||||
case at::ScalarType::BFloat16: { \
|
||||
using scalar_t = at::BFloat16; \
|
||||
using param_t = float; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
case at::ScalarType::Half: { \
|
||||
using scalar_t = at::Half; \
|
||||
using param_t = float; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
default: \
|
||||
TORCH_CHECK(false, "Unsupported floating data type.\n"); \
|
||||
} \
|
||||
} else { \
|
||||
TORCH_CHECK(TYPE1 == TYPE2); \
|
||||
switch (TYPE1) { \
|
||||
case at::ScalarType::BFloat16: { \
|
||||
using scalar_t = at::BFloat16; \
|
||||
using param_t = at::BFloat16; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
case at::ScalarType::Half: { \
|
||||
using scalar_t = at::Half; \
|
||||
using param_t = at::Half; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
default: \
|
||||
TORCH_CHECK(false, "Unsupported floating data type.\n"); \
|
||||
} \
|
||||
} \
|
||||
}()
|
||||
|
||||
#define UNUSED(x) (void)(x)
|
||||
|
||||
#define CHECK_CPU(x) TORCH_CHECK(x.device().type() == at::kCPU, #x " must be a CPU tensor")
|
||||
|
||||
@@ -252,29 +252,33 @@ void topk_softmax_kernel_impl(
|
||||
});
|
||||
}
|
||||
|
||||
template <typename scalar_t, int SIZE>
|
||||
template <typename scalar_t, typename param_t, int SIZE>
|
||||
inline void
|
||||
apply_bias(float* __restrict__ scores2, const float* __restrict__ scores, const scalar_t* __restrict__ bias) {
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
apply_bias(float* __restrict__ scores2, const float* __restrict__ scores, const param_t* __restrict__ bias) {
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
for (int d = 0; d < SIZE; d += bVec::size()) {
|
||||
bVec bias_vec = bVec::loadu(bias + d);
|
||||
fVec bias0, bias1;
|
||||
std::tie(bias0, bias1) = at::vec::convert_to_float(bias_vec);
|
||||
|
||||
fVec x0 = fVec::loadu(scores + d) + bias0;
|
||||
fVec x1 = fVec::loadu(scores + d + fVec::size()) + bias1;
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
auto vec_size = bVec::size();
|
||||
int d = 0;
|
||||
for (; d <= SIZE - vec_size; d += vec_size) {
|
||||
fVec bias0, bias1, x0, x1;
|
||||
std::tie(bias0, bias1) = load_float_vec2(bias + d);
|
||||
std::tie(x0, x1) = load_float_vec2(scores + d);
|
||||
x0 = x0 + bias0;
|
||||
x1 = x1 + bias1;
|
||||
x0.store(scores2 + d);
|
||||
x1.store(scores2 + d + fVec::size());
|
||||
}
|
||||
for (; d < SIZE; d++) {
|
||||
scores2[d] = scores[d] + (float)bias[d];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t, int NUM_EXPERTS, int TOPK>
|
||||
template <typename scalar_t, typename param_t, int NUM_EXPERTS, int TOPK>
|
||||
void biased_grouped_topk_kernel_impl(
|
||||
float* __restrict__ topk_weights,
|
||||
int32_t* __restrict__ topk_ids,
|
||||
const scalar_t* __restrict__ gating_output,
|
||||
const scalar_t* __restrict__ bias,
|
||||
const param_t* __restrict__ bias,
|
||||
int64_t num_tokens,
|
||||
int64_t num_groups,
|
||||
int64_t topk_group,
|
||||
@@ -295,7 +299,8 @@ void biased_grouped_topk_kernel_impl(
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
// do sigmoid to get scores
|
||||
sigmoid<scalar_t, NUM_EXPERTS>(scores, gating_output + i * NUM_EXPERTS);
|
||||
apply_bias<scalar_t, NUM_EXPERTS>(scores2, scores, bias);
|
||||
|
||||
apply_bias<scalar_t, param_t, NUM_EXPERTS>(scores2, scores, bias);
|
||||
|
||||
for (int64_t g = 0; g < num_groups; ++g) {
|
||||
// find the max
|
||||
@@ -406,15 +411,15 @@ void biased_grouped_topk_kernel_impl(
|
||||
topk, \
|
||||
renormalize);
|
||||
|
||||
#define LAUNCH_BIASED_GROUPED_TOPK_KERNEL(NE, NTOPK) \
|
||||
biased_grouped_topk_kernel_impl<scalar_t, NE, NTOPK>( \
|
||||
topk_weights.data_ptr<float>(), \
|
||||
topk_ids.data_ptr<int32_t>(), \
|
||||
gating_output.data_ptr<scalar_t>(), \
|
||||
correction_bias.data_ptr<scalar_t>(), \
|
||||
num_tokens, \
|
||||
num_expert_group, \
|
||||
topk_group, \
|
||||
#define LAUNCH_BIASED_GROUPED_TOPK_KERNEL(NE, NTOPK) \
|
||||
biased_grouped_topk_kernel_impl<scalar_t, param_t, NE, NTOPK>( \
|
||||
topk_weights.data_ptr<float>(), \
|
||||
topk_ids.data_ptr<int32_t>(), \
|
||||
gating_output.data_ptr<scalar_t>(), \
|
||||
correction_bias.data_ptr<param_t>(), \
|
||||
num_tokens, \
|
||||
num_expert_group, \
|
||||
topk_group, \
|
||||
renormalize);
|
||||
|
||||
} // anonymous namespace
|
||||
@@ -635,7 +640,6 @@ std::tuple<at::Tensor, at::Tensor> biased_grouped_topk_cpu(
|
||||
|
||||
const auto st = hidden_states.scalar_type();
|
||||
CHECK_EQ(gating_output.scalar_type(), st);
|
||||
CHECK_EQ(correction_bias.scalar_type(), st);
|
||||
|
||||
int64_t num_tokens = hidden_states.size(0);
|
||||
int64_t num_experts = gating_output.size(1);
|
||||
@@ -644,8 +648,7 @@ std::tuple<at::Tensor, at::Tensor> biased_grouped_topk_cpu(
|
||||
at::Tensor topk_weights = at::empty({num_tokens, topk}, hidden_states.options().dtype(at::kFloat));
|
||||
at::Tensor topk_ids = at::empty({num_tokens, topk}, hidden_states.options().dtype(at::kInt));
|
||||
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "biased_grouped_topk_kernel", [&] {
|
||||
// NOW only support DSv3 configs
|
||||
CPU_DISPATCH_REDUCED_FLOATING_TYPES_EXT(st, correction_bias.scalar_type(), "biased_grouped_topk_kernel", [&] {
|
||||
TORCH_CHECK(topk == 8, "Unexpected topk: ", topk);
|
||||
switch (num_experts) {
|
||||
case 256:
|
||||
|
||||
@@ -16,6 +16,25 @@ inline Vectorized<scalar_t> convert_from_float_ext(const Vectorized<float>& a, c
|
||||
return at::vec::convert_from_float<scalar_t>(a, b);
|
||||
}
|
||||
|
||||
// allow f16, bf16
|
||||
template <typename scalar_t, typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 1>
|
||||
inline std::tuple<Vectorized<float>, Vectorized<float>> load_float_vec2(const scalar_t* __restrict__ data) {
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
bVec x_vec = bVec::loadu(data);
|
||||
fVec x0, x1;
|
||||
std::tie(x0, x1) = at::vec::convert_to_float(x_vec);
|
||||
return std::make_tuple(x0, x1);
|
||||
}
|
||||
|
||||
// allow f32
|
||||
inline std::tuple<Vectorized<float>, Vectorized<float>> load_float_vec2(const float* __restrict__ data) {
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
fVec x0 = fVec::loadu(data);
|
||||
fVec x1 = fVec::loadu(data + fVec::size());
|
||||
return std::make_tuple(x0, x1);
|
||||
}
|
||||
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
|
||||
// `at::vec::convert_from_float<>` from PyTorch doesn't have avx512-bf16 intrinsics
|
||||
|
||||
@@ -66,13 +66,15 @@ class TestGroupedTopK(CustomTestCase):
|
||||
|
||||
# DeepSeek V2/V3/R1 uses biased_grouped_top
|
||||
class TestBiasedGroupedTopK(CustomTestCase):
|
||||
def _run_single_test(self, M, E, G, topk, topk_group, renormalize, dtype):
|
||||
def _run_single_test(
|
||||
self, M, E, G, topk, topk_group, renormalize, dtype, bias_dtype
|
||||
):
|
||||
torch.manual_seed(1234)
|
||||
|
||||
# expand gating_output by M, otherwise bfloat16 fall into same value aftering truncating
|
||||
hidden_states = torch.randn(M, 100, dtype=dtype)
|
||||
gating_output = torch.randn(M, E, dtype=dtype) * 2 * M
|
||||
correction_bias = torch.randn(E, dtype=dtype)
|
||||
correction_bias = torch.randn(E, dtype=bias_dtype)
|
||||
|
||||
ref_topk_weights, ref_topk_ids = native_biased_grouped_topk(
|
||||
hidden_states.float(),
|
||||
@@ -106,7 +108,10 @@ class TestBiasedGroupedTopK(CustomTestCase):
|
||||
|
||||
def test_biased_grouped_topk(self):
|
||||
for renormalize in [True, False]:
|
||||
self._run_single_test(122, 256, 8, 8, 2, renormalize, torch.bfloat16)
|
||||
for bias_dtype in [torch.float32, torch.bfloat16]:
|
||||
self._run_single_test(
|
||||
122, 256, 8, 8, 2, renormalize, torch.bfloat16, bias_dtype
|
||||
)
|
||||
|
||||
|
||||
class TestTopK(CustomTestCase):
|
||||
|
||||
Reference in New Issue
Block a user