diff --git a/sgl-kernel/csrc/cpu/common.h b/sgl-kernel/csrc/cpu/common.h
index 6f09a0922..1bf45ee4b 100644
--- a/sgl-kernel/csrc/cpu/common.h
+++ b/sgl-kernel/csrc/cpu/common.h
@@ -47,6 +47,45 @@ namespace {
     }                                                                \
   }()
 
+// dispatch with mixed dtypes (TYPE1, TYPE2):
+// TYPE1: the primary dtype (input, output, weight);
+// TYPE2: the secondary dtype (bias, etc.).
+#define CPU_DISPATCH_REDUCED_FLOATING_TYPES_EXT(TYPE1, TYPE2, NAME, ...) \
+  [&] {                                                                  \
+    if (TYPE2 == at::kFloat) {                                           \
+      switch (TYPE1) {                                                   \
+        case at::ScalarType::BFloat16: {                                 \
+          using scalar_t = at::BFloat16;                                 \
+          using param_t = float;                                         \
+          return __VA_ARGS__();                                          \
+        }                                                                \
+        case at::ScalarType::Half: {                                     \
+          using scalar_t = at::Half;                                     \
+          using param_t = float;                                         \
+          return __VA_ARGS__();                                          \
+        }                                                                \
+        default:                                                         \
+          TORCH_CHECK(false, "Unsupported floating data type.\n");       \
+      }                                                                  \
+    } else {                                                             \
+      TORCH_CHECK(TYPE1 == TYPE2);                                       \
+      switch (TYPE1) {                                                   \
+        case at::ScalarType::BFloat16: {                                 \
+          using scalar_t = at::BFloat16;                                 \
+          using param_t = at::BFloat16;                                  \
+          return __VA_ARGS__();                                          \
+        }                                                                \
+        case at::ScalarType::Half: {                                     \
+          using scalar_t = at::Half;                                     \
+          using param_t = at::Half;                                      \
+          return __VA_ARGS__();                                          \
+        }                                                                \
+        default:                                                         \
+          TORCH_CHECK(false, "Unsupported floating data type.\n");       \
+      }                                                                  \
+    }                                                                    \
+  }()
+
 #define UNUSED(x) (void)(x)
 
 #define CHECK_CPU(x) TORCH_CHECK(x.device().type() == at::kCPU, #x " must be a CPU tensor")
diff --git a/sgl-kernel/csrc/cpu/topk.cpp b/sgl-kernel/csrc/cpu/topk.cpp
index cdfa4c271..abc5a34fa 100644
--- a/sgl-kernel/csrc/cpu/topk.cpp
+++ b/sgl-kernel/csrc/cpu/topk.cpp
@@ -252,29 +252,33 @@ void topk_softmax_kernel_impl(
   });
 }
 
-template <typename scalar_t, int SIZE>
+template <typename param_t, int SIZE>
 inline void
-apply_bias(float* __restrict__ scores2, const float* __restrict__ scores, const scalar_t* __restrict__ bias) {
-  using bVec = at::vec::Vectorized<scalar_t>;
+apply_bias(float* __restrict__ scores2, const float* __restrict__ scores, const param_t* __restrict__ bias) {
   using fVec = at::vec::Vectorized<float>;
-  for (int d = 0; d < SIZE; d += bVec::size()) {
-    bVec bias_vec = bVec::loadu(bias + d);
-    fVec bias0, bias1;
-    std::tie(bias0, bias1) = at::vec::convert_to_float(bias_vec);
-
-    fVec x0 = fVec::loadu(scores + d) + bias0;
-    fVec x1 = fVec::loadu(scores + d + fVec::size()) + bias1;
+  // both load_float_vec2 overloads yield exactly 2 * fVec::size() floats per call
+  constexpr int vec_size = 2 * fVec::size();
+  int d = 0;
+  for (; d <= SIZE - vec_size; d += vec_size) {
+    fVec bias0, bias1, x0, x1;
+    std::tie(bias0, bias1) = load_float_vec2(bias + d);
+    std::tie(x0, x1) = load_float_vec2(scores + d);
+    x0 = x0 + bias0;
+    x1 = x1 + bias1;
     x0.store(scores2 + d);
     x1.store(scores2 + d + fVec::size());
   }
+  for (; d < SIZE; d++) {
+    scores2[d] = scores[d] + (float)bias[d];
+  }
 }
 
-template <typename scalar_t, int NUM_EXPERTS, int TOPK>
+template <typename scalar_t, typename param_t, int NUM_EXPERTS, int TOPK>
 void biased_grouped_topk_kernel_impl(
     float* __restrict__ topk_weights,
     int32_t* __restrict__ topk_ids,
     const scalar_t* __restrict__ gating_output,
-    const scalar_t* __restrict__ bias,
+    const param_t* __restrict__ bias,
     int64_t num_tokens,
     int64_t num_groups,
     int64_t topk_group,
@@ -295,7 +299,8 @@ void biased_grouped_topk_kernel_impl(
   for (int64_t i = begin; i < end; ++i) {
     // do sigmoid to get scores
     sigmoid(scores, gating_output + i * NUM_EXPERTS);
-    apply_bias<scalar_t, NUM_EXPERTS>(scores2, scores, bias);
+
+    apply_bias<param_t, NUM_EXPERTS>(scores2, scores, bias);
 
     for (int64_t g = 0; g < num_groups; ++g) {
       // find the max
@@ -406,15 +411,15 @@ void biased_grouped_topk_kernel_impl(
       topk,                                                 \
       renormalize);
 
-#define LAUNCH_BIASED_GROUPED_TOPK_KERNEL(NE, NTOPK)        \
-  biased_grouped_topk_kernel_impl<scalar_t, NE, NTOPK>(     \
-      topk_weights.data_ptr<float>(),                       \
-      topk_ids.data_ptr<int32_t>(),                         \
-      gating_output.data_ptr<scalar_t>(),                   \
-      correction_bias.data_ptr<scalar_t>(),                 \
-      num_tokens,                                           \
-      num_expert_group,                                     \
-      topk_group,                                           \
+#define LAUNCH_BIASED_GROUPED_TOPK_KERNEL(NE, NTOPK)                 \
+  biased_grouped_topk_kernel_impl<scalar_t, param_t, NE, NTOPK>(     \
+      topk_weights.data_ptr<float>(),                                \
+      topk_ids.data_ptr<int32_t>(),                                  \
+      gating_output.data_ptr<scalar_t>(),                            \
+      correction_bias.data_ptr<param_t>(),                           \
+      num_tokens,                                                    \
+      num_expert_group,                                              \
+      topk_group,                                                    \
       renormalize);
 
 } // anonymous namespace
@@ -635,7 +640,6 @@ std::tuple<at::Tensor, at::Tensor> biased_grouped_topk_cpu(
   const auto st =
       hidden_states.scalar_type();
   CHECK_EQ(gating_output.scalar_type(), st);
-  CHECK_EQ(correction_bias.scalar_type(), st);
 
   int64_t num_tokens = hidden_states.size(0);
   int64_t num_experts = gating_output.size(1);
@@ -644,8 +648,7 @@ std::tuple<at::Tensor, at::Tensor> biased_grouped_topk_cpu(
   at::Tensor topk_weights = at::empty({num_tokens, topk}, hidden_states.options().dtype(at::kFloat));
   at::Tensor topk_ids = at::empty({num_tokens, topk}, hidden_states.options().dtype(at::kInt));
 
-  AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "biased_grouped_topk_kernel", [&] {
-    // NOW only support DSv3 configs
+  CPU_DISPATCH_REDUCED_FLOATING_TYPES_EXT(st, correction_bias.scalar_type(), "biased_grouped_topk_kernel", [&] {
     TORCH_CHECK(topk == 8, "Unexpected topk: ", topk);
     switch (num_experts) {
       case 256:
diff --git a/sgl-kernel/csrc/cpu/vec.h b/sgl-kernel/csrc/cpu/vec.h
index 9f8eaad18..d28124c1d 100644
--- a/sgl-kernel/csrc/cpu/vec.h
+++ b/sgl-kernel/csrc/cpu/vec.h
@@ -16,6 +16,25 @@ inline Vectorized<scalar_t> convert_from_float_ext(const Vectorized<float>& a, c
   return at::vec::convert_from_float<scalar_t>(a, b);
 }
 
+// allow f16, bf16
+template <typename scalar_t, typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 1>
+inline std::tuple<Vectorized<float>, Vectorized<float>> load_float_vec2(const scalar_t* __restrict__ data) {
+  using bVec = at::vec::Vectorized<scalar_t>;
+  using fVec = at::vec::Vectorized<float>;
+  bVec x_vec = bVec::loadu(data);
+  fVec x0, x1;
+  std::tie(x0, x1) = at::vec::convert_to_float(x_vec);
+  return std::make_tuple(x0, x1);
+}
+
+// allow f32
+inline std::tuple<Vectorized<float>, Vectorized<float>> load_float_vec2(const float* __restrict__ data) {
+  using fVec = at::vec::Vectorized<float>;
+  fVec x0 = fVec::loadu(data);
+  fVec x1 = fVec::loadu(data + fVec::size());
+  return std::make_tuple(x0, x1);
+}
+
 #if defined(CPU_CAPABILITY_AVX512)
 
 // `at::vec::convert_from_float<>` from PyTorch doesn't have avx512-bf16 intrinsics
diff --git a/test/srt/cpu/test_topk.py b/test/srt/cpu/test_topk.py
index 0e0aeef2c..4b4ce21ae 100644
--- a/test/srt/cpu/test_topk.py
+++ b/test/srt/cpu/test_topk.py
@@ -66,13 +66,15 @@ class TestGroupedTopK(CustomTestCase):
 
 
 # DeepSeek V2/V3/R1 uses biased_grouped_topk
 class TestBiasedGroupedTopK(CustomTestCase):
-    def _run_single_test(self, M, E, G, topk, topk_group, renormalize, dtype):
+    def _run_single_test(
+        self, M, E, G, topk, topk_group, renormalize, dtype, bias_dtype
+    ):
         torch.manual_seed(1234)
         # expand gating_output by M, otherwise bfloat16 fall into same value aftering truncating
         hidden_states = torch.randn(M, 100, dtype=dtype)
         gating_output = torch.randn(M, E, dtype=dtype) * 2 * M
-        correction_bias = torch.randn(E, dtype=dtype)
+        correction_bias = torch.randn(E, dtype=bias_dtype)
 
         ref_topk_weights, ref_topk_ids = native_biased_grouped_topk(
             hidden_states.float(),
@@ -106,7 +108,10 @@ class TestBiasedGroupedTopK(CustomTestCase):
 
     def test_biased_grouped_topk(self):
         for renormalize in [True, False]:
-            self._run_single_test(122, 256, 8, 8, 2, renormalize, torch.bfloat16)
+            for bias_dtype in [torch.float32, torch.bfloat16]:
+                self._run_single_test(
+                    122, 256, 8, 8, 2, renormalize, torch.bfloat16, bias_dtype
+                )
 
 
 class TestTopK(CustomTestCase):