From 2998c4bdf4fea00081978b320b60f4b0ce905297 Mon Sep 17 00:00:00 2001 From: Yi Zhang <1109276519@qq.com> Date: Fri, 4 Jul 2025 03:42:44 +0800 Subject: [PATCH] [optimize] fuse renormalize into moe_topk_softmax (#7744) Co-authored-by: ispobock --- .../benchmark/bench_moe_topk_softmax.py | 4 - sgl-kernel/csrc/common_extension.cc | 4 +- .../csrc/moe/moe_topk_softmax_kernels.cu | 238 ++++++++++++------ sgl-kernel/csrc/torch_extension_rocm.cc | 4 +- sgl-kernel/include/sgl_kernel_ops.h | 5 +- sgl-kernel/python/sgl_kernel/moe.py | 4 +- sgl-kernel/tests/test_moe_topk_softmax.py | 96 ++++++- 7 files changed, 254 insertions(+), 101 deletions(-) diff --git a/sgl-kernel/benchmark/bench_moe_topk_softmax.py b/sgl-kernel/benchmark/bench_moe_topk_softmax.py index 5598cfbec..1d3e3e93f 100644 --- a/sgl-kernel/benchmark/bench_moe_topk_softmax.py +++ b/sgl-kernel/benchmark/bench_moe_topk_softmax.py @@ -34,14 +34,10 @@ def sglang_topk_softmax(gating_output, topk): topk_indices = torch.empty( (num_tokens, topk), dtype=torch.int32, device=gating_output.device ) - token_expert_indices = torch.empty( - (num_tokens, topk), dtype=torch.int32, device=gating_output.device - ) topk_softmax( topk_weights=topk_weights, topk_ids=topk_indices, - token_expert_indices=token_expert_indices, gating_output=gating_output, ) diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc index 0aee1c1bf..b6a22152a 100644 --- a/sgl-kernel/csrc/common_extension.cc +++ b/sgl-kernel/csrc/common_extension.cc @@ -169,9 +169,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { "pad_sorted_token_ids) -> ()"); m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size); - m.def( - "topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! " - "token_expert_indices, Tensor gating_output) -> ()"); + m.def("topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor gating_output, bool renormalize) -> ()"); m.impl("topk_softmax", torch::kCUDA, &topk_softmax); m.def( diff --git a/sgl-kernel/csrc/moe/moe_topk_softmax_kernels.cu b/sgl-kernel/csrc/moe/moe_topk_softmax_kernels.cu index ac78ebb12..050e8d52b 100644 --- a/sgl-kernel/csrc/moe/moe_topk_softmax_kernels.cu +++ b/sgl-kernel/csrc/moe/moe_topk_softmax_kernels.cu @@ -41,15 +41,29 @@ template < /// Alignment requirement in bytes int Alignment = sizeof(T) * N> class alignas(Alignment) AlignedArray { - float data[N]; + T data[N]; }; +// ========================== Util functions to convert types ========================== +template +__device__ float convert_to_float(T x) { + if constexpr (std::is_same_v) { + return __half2float(x); + } else if constexpr (std::is_same_v) { + return __bfloat162float(x); + } else if constexpr (std::is_same_v) { + return x; + } else { + return static_cast(x); + } +} + // ====================== Softmax things =============================== // We have our own implementation of softmax here so we can support transposing the output // in the softmax kernel when we extend this module to support expert-choice routing. -template +template __launch_bounds__(TPB) __global__ - void moeSoftmax(const float* input, const bool* finished, float* output, const int num_cols) { + void moeSoftmax(const T* input, const bool* finished, float* output, const int num_cols) { using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage tmpStorage; @@ -68,7 +82,7 @@ __launch_bounds__(TPB) __global__ for (int ii = threadIdx.x; ii < num_cols; ii += TPB) { const int idx = thread_row_offset + ii; - threadData = max(static_cast(input[idx]), threadData); + threadData = max(convert_to_float(input[idx]), threadData); } const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max()); @@ -82,7 +96,7 @@ __launch_bounds__(TPB) __global__ for (int ii = threadIdx.x; ii < num_cols; ii += TPB) { const int idx = thread_row_offset + ii; - threadData += exp((static_cast(input[idx]) - float_max)); + threadData += exp((convert_to_float(input[idx]) - float_max)); } const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum); @@ -94,7 +108,7 @@ __launch_bounds__(TPB) __global__ for (int ii = threadIdx.x; ii < num_cols; ii += TPB) { const int idx = thread_row_offset + ii; - const float val = exp((static_cast(input[idx]) - float_max)) * normalizing_factor; + const float val = exp((convert_to_float(input[idx]) - float_max)) * normalizing_factor; output[idx] = val; } } @@ -105,11 +119,11 @@ __launch_bounds__(TPB) __global__ void moeTopK( const bool* finished, float* output, int* indices, - int* source_rows, const int num_experts, const int k, const int start_expert, - const int end_expert) { + const int end_expert, + const bool renormalize) { using cub_kvp = cub::KeyValuePair; using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage tmpStorage; @@ -117,11 +131,11 @@ __launch_bounds__(TPB) __global__ void moeTopK( cub_kvp thread_kvp; cub::ArgMax arg_max; - const int num_rows = gridDim.x; const int block_row = blockIdx.x; const bool row_is_active = finished ? !finished[block_row] : true; const int thread_read_offset = blockIdx.x * num_experts; + float row_sum_for_renormalize = 0; for (int k_idx = 0; k_idx < k; ++k_idx) { thread_kvp.key = 0; thread_kvp.value = -1.f; // This is OK because inputs are probabilities @@ -154,10 +168,18 @@ __launch_bounds__(TPB) __global__ void moeTopK( output[idx] = result_kvp.value; indices[idx] = should_process_row ? (expert - start_expert) : num_experts; assert(indices[idx] >= 0); - source_rows[idx] = k_idx * num_rows + block_row; + row_sum_for_renormalize += result_kvp.value; } __syncthreads(); } + + if (renormalize && threadIdx.x == 0) { + float row_sum_for_renormalize_inv = 1.f / row_sum_for_renormalize; + for (int k_idx = 0; k_idx < k; ++k_idx) { + const int idx = k * block_row + k_idx; + output[idx] = output[idx] * row_sum_for_renormalize_inv; + } + } } // ====================== TopK softmax things =============================== @@ -174,17 +196,17 @@ __launch_bounds__(TPB) __global__ void moeTopK( 2) This implementation assumes k is small, but will work for any k. */ -template +template __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ void topkGatingSoftmax( - const float* input, + const T* input, const bool* finished, float* output, const int num_rows, int* indices, - int* source_rows, const int k, const int start_expert, - const int end_expert) { + const int end_expert, + const bool renormalize) { // We begin by enforcing compile time assertions and setting up compile time constants. static_assert(VPT == (VPT & -VPT), "VPT must be power of 2"); static_assert(NUM_EXPERTS == (NUM_EXPERTS & -NUM_EXPERTS), "NUM_EXPERTS must be power of 2"); @@ -192,7 +214,7 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ void topkGatingSoftmax( static_assert(BYTES_PER_LDG <= 16, "BYTES_PER_LDG must be leq 16"); // Number of bytes each thread pulls in per load - static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float); + static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(T); static constexpr int ELTS_PER_ROW = NUM_EXPERTS; static constexpr int THREADS_PER_ROW = ELTS_PER_ROW / VPT; static constexpr int LDG_PER_THREAD = VPT / ELTS_PER_LDG; @@ -233,28 +255,34 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ void topkGatingSoftmax( // We finally start setting up the read pointers for each thread. First, each thread jumps to the start of the // row it will read. - const float* thread_row_ptr = input + thread_row * ELTS_PER_ROW; + const T* thread_row_ptr = input + thread_row * ELTS_PER_ROW; // Now, we compute the group each thread belong to in order to determine the first column to start loads. const int thread_group_idx = threadIdx.x % THREADS_PER_ROW; const int first_elt_read_by_thread = thread_group_idx * ELTS_PER_LDG; - const float* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread; + const T* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread; // Determine the pointer type to use to read in the data depending on the BYTES_PER_LDG template param. In theory, // this can support all powers of 2 up to 16. // NOTE(woosuk): The original implementation uses CUTLASS aligned array here. // We defined our own aligned array and use it here to avoid the dependency on CUTLASS. - using AccessType = AlignedArray; + using AccessType = AlignedArray; // Finally, we pull in the data from global mem - float row_chunk[VPT]; - AccessType* row_chunk_vec_ptr = reinterpret_cast(&row_chunk); + T row_chunk_temp[VPT]; + AccessType* row_chunk_vec_ptr = reinterpret_cast(&row_chunk_temp); const AccessType* vec_thread_read_ptr = reinterpret_cast(thread_read_ptr); #pragma unroll for (int ii = 0; ii < LDG_PER_THREAD; ++ii) { row_chunk_vec_ptr[ii] = vec_thread_read_ptr[ii * THREADS_PER_ROW]; } + float row_chunk[VPT]; +#pragma unroll + for (int ii = 0; ii < VPT; ++ii) { + row_chunk[ii] = convert_to_float(row_chunk_temp[ii]); + } + // First, we perform a max reduce within the thread. We can do the max in fp16 safely (I think) and just // convert to float afterwards for the exp + sum reduction. float thread_max = row_chunk[0]; @@ -301,6 +329,8 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ void topkGatingSoftmax( int start_col = first_elt_read_by_thread; static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW; + float row_sum_for_renormalize = 0; + for (int k_idx = 0; k_idx < k; ++k_idx) { // First, each thread does the local argmax float max_val = row_chunk[0]; @@ -346,7 +376,7 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ void topkGatingSoftmax( const int idx = k * thread_row + k_idx; output[idx] = max_val; indices[idx] = should_process_row ? (expert - start_expert) : NUM_EXPERTS; - source_rows[idx] = k_idx * num_rows + thread_row; + row_sum_for_renormalize += max_val; } // Finally, we clear the value in the thread with the current max if there is another iteration to run. @@ -362,13 +392,23 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ void topkGatingSoftmax( } } } + + // Fuse renormalization of topk_weights into this kernel + if (renormalize && thread_group_idx == 0) { + float row_sum_for_renormalize_inv = 1.f / row_sum_for_renormalize; +#pragma unroll + for (int k_idx = 0; k_idx < k; ++k_idx) { + const int idx = k * thread_row + k_idx; + output[idx] = output[idx] * row_sum_for_renormalize_inv; + } + } } namespace detail { // Constructs some constants needed to partition the work across threads at compile time. -template +template struct TopkConstants { - static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float); + static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(T); static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0, ""); static constexpr int VECs_PER_THREAD = MAX(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE)); static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG; @@ -377,113 +417,120 @@ struct TopkConstants { }; } // namespace detail -template +template void topkGatingSoftmaxLauncherHelper( - const float* input, + const T* input, const bool* finished, float* output, int* indices, - int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, + const bool renormalize, cudaStream_t stream) { static constexpr std::size_t MAX_BYTES_PER_LDG = 16; - static constexpr int BYTES_PER_LDG = MIN(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS); - using Constants = detail::TopkConstants; + static constexpr int BYTES_PER_LDG = MIN(MAX_BYTES_PER_LDG, sizeof(T) * EXPERTS); + using Constants = detail::TopkConstants; static constexpr int VPT = Constants::VPT; static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP; const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP; const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB; dim3 block_dim(WARP_SIZE, WARPS_PER_TB); - topkGatingSoftmax<<>>( - input, finished, output, num_rows, indices, source_row, k, start_expert, end_expert); + topkGatingSoftmax<<>>( + input, finished, output, num_rows, indices, k, start_expert, end_expert, renormalize); } -#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB) \ - topkGatingSoftmaxLauncherHelper( \ - gating_output, \ - nullptr, \ - topk_weights, \ - topk_indices, \ - token_expert_indices, \ - num_tokens, \ - topk, \ - 0, \ - num_experts, \ - stream); +#define LAUNCH_SOFTMAX(TYPE, NUM_EXPERTS, WARPS_PER_TB) \ + topkGatingSoftmaxLauncherHelper( \ + gating_output, nullptr, topk_weights, topk_indices, num_tokens, topk, 0, num_experts, renormalize, stream); +template void topkGatingSoftmaxKernelLauncher( - const float* gating_output, + const T* gating_output, float* topk_weights, int* topk_indices, - int* token_expert_indices, float* softmax_workspace, const int num_tokens, const int num_experts, const int topk, + const bool renormalize, cudaStream_t stream) { static constexpr int WARPS_PER_TB = 4; switch (num_experts) { case 1: - LAUNCH_SOFTMAX(1, WARPS_PER_TB); + LAUNCH_SOFTMAX(T, 1, WARPS_PER_TB); break; case 2: - LAUNCH_SOFTMAX(2, WARPS_PER_TB); + LAUNCH_SOFTMAX(T, 2, WARPS_PER_TB); break; case 4: - LAUNCH_SOFTMAX(4, WARPS_PER_TB); + LAUNCH_SOFTMAX(T, 4, WARPS_PER_TB); break; case 8: - LAUNCH_SOFTMAX(8, WARPS_PER_TB); + LAUNCH_SOFTMAX(T, 8, WARPS_PER_TB); break; case 16: - LAUNCH_SOFTMAX(16, WARPS_PER_TB); + LAUNCH_SOFTMAX(T, 16, WARPS_PER_TB); break; case 32: - LAUNCH_SOFTMAX(32, WARPS_PER_TB); + LAUNCH_SOFTMAX(T, 32, WARPS_PER_TB); break; case 64: - LAUNCH_SOFTMAX(64, WARPS_PER_TB); + LAUNCH_SOFTMAX(T, 64, WARPS_PER_TB); break; case 128: - LAUNCH_SOFTMAX(128, WARPS_PER_TB); + LAUNCH_SOFTMAX(T, 128, WARPS_PER_TB); break; case 256: - LAUNCH_SOFTMAX(256, WARPS_PER_TB); + LAUNCH_SOFTMAX(T, 256, WARPS_PER_TB); break; default: { TORCH_CHECK( softmax_workspace != nullptr, "softmax_workspace must be provided for num_experts that are not a power of 2."); static constexpr int TPB = 256; - moeSoftmax<<>>(gating_output, nullptr, softmax_workspace, num_experts); + moeSoftmax<<>>(gating_output, nullptr, softmax_workspace, num_experts); moeTopK<<>>( - softmax_workspace, - nullptr, - topk_weights, - topk_indices, - token_expert_indices, - num_experts, - topk, - 0, - num_experts); + softmax_workspace, nullptr, topk_weights, topk_indices, num_experts, topk, 0, num_experts, renormalize); } } } void topk_softmax( - torch::Tensor& topk_weights, // [num_tokens, topk] - torch::Tensor& topk_indices, // [num_tokens, topk] - torch::Tensor& token_expert_indices, // [num_tokens, topk] - torch::Tensor& gating_output) // [num_tokens, num_experts] + torch::Tensor& topk_weights, // [num_tokens, topk] + torch::Tensor& topk_indices, // [num_tokens, topk] + torch::Tensor& gating_output, + const bool renormalize) // [num_tokens, num_experts] { - const int num_experts = gating_output.size(-1); - const int num_tokens = gating_output.numel() / num_experts; - const int topk = topk_weights.size(-1); + // Check data type + TORCH_CHECK( + gating_output.scalar_type() == at::ScalarType::Float || gating_output.scalar_type() == at::ScalarType::Half || + gating_output.scalar_type() == at::ScalarType::BFloat16, + "gating_output must be float32, float16, or bfloat16"); + + // Check dimensions + TORCH_CHECK(gating_output.dim() == 2, "gating_output must be 2D tensor [num_tokens, num_experts]"); + TORCH_CHECK(topk_weights.dim() == 2, "topk_weights must be 2D tensor [num_tokens, topk]"); + TORCH_CHECK(topk_indices.dim() == 2, "topk_indices must be 2D tensor [num_tokens, topk]"); + + // Check shapes + TORCH_CHECK( + gating_output.size(0) == topk_weights.size(0), + "First dimension of topk_weights must match num_tokens in gating_output"); + TORCH_CHECK( + gating_output.size(0) == topk_indices.size(0), + "First dimension of topk_indices must match num_tokens in gating_output"); + TORCH_CHECK( + topk_weights.size(-1) == topk_indices.size(-1), + "Second dimension of topk_indices must match topk in topk_weights"); + TORCH_CHECK(topk_weights.size(-1) <= gating_output.size(-1), "topk must be less than or equal to num_experts"); + + const int num_experts = static_cast(gating_output.size(-1)); + const int num_tokens = static_cast(gating_output.size(0)); + const int topk = static_cast(topk_weights.size(-1)); const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0); const bool needs_workspace = !is_pow_2 || num_experts > 256; @@ -491,15 +538,44 @@ void topk_softmax( const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - torch::Tensor softmax_workspace = torch::empty({workspace_size}, gating_output.options()); - topkGatingSoftmaxKernelLauncher( - gating_output.data_ptr(), - topk_weights.data_ptr(), - topk_indices.data_ptr(), - token_expert_indices.data_ptr(), - softmax_workspace.data_ptr(), - num_tokens, - num_experts, - topk, - stream); + torch::Tensor softmax_workspace = + torch::empty({workspace_size}, gating_output.options().dtype(at::ScalarType::Float)); + + const at::ScalarType dtype = gating_output.scalar_type(); + if (dtype == at::ScalarType::Float) { + topkGatingSoftmaxKernelLauncher( + gating_output.data_ptr(), + topk_weights.data_ptr(), + topk_indices.data_ptr(), + softmax_workspace.data_ptr(), + num_tokens, + num_experts, + topk, + renormalize, + stream); + } else if (dtype == at::ScalarType::Half) { + topkGatingSoftmaxKernelLauncher<__half>( + reinterpret_cast(gating_output.data_ptr()), + topk_weights.data_ptr(), + topk_indices.data_ptr(), + softmax_workspace.data_ptr(), + num_tokens, + num_experts, + topk, + renormalize, + stream); + } else if (dtype == at::ScalarType::BFloat16) { + topkGatingSoftmaxKernelLauncher<__nv_bfloat16>( + reinterpret_cast(gating_output.data_ptr()), + topk_weights.data_ptr(), + topk_indices.data_ptr(), + softmax_workspace.data_ptr(), + num_tokens, + num_experts, + topk, + renormalize, + stream); + } else { + TORCH_CHECK(false, "Unsupported gating_output dtype: ", dtype); + } } diff --git a/sgl-kernel/csrc/torch_extension_rocm.cc b/sgl-kernel/csrc/torch_extension_rocm.cc index 0b1acf685..0e3f48e61 100644 --- a/sgl-kernel/csrc/torch_extension_rocm.cc +++ b/sgl-kernel/csrc/torch_extension_rocm.cc @@ -63,9 +63,7 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) { "pad_sorted_token_ids) -> ()"); m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size); - m.def( - "topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! " - "token_expert_indices, Tensor gating_output) -> ()"); + m.def("topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor gating_output, bool renormalize) -> ()"); m.impl("topk_softmax", torch::kCUDA, &topk_softmax); /* diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h index d836f848a..c53ecdc01 100644 --- a/sgl-kernel/include/sgl_kernel_ops.h +++ b/sgl-kernel/include/sgl_kernel_ops.h @@ -222,10 +222,7 @@ void moe_align_block_size( bool pad_sorted_token_ids); void topk_softmax( - torch::Tensor& topk_weights, - torch::Tensor& topk_indices, - torch::Tensor& token_expert_indices, - torch::Tensor& gating_output); + torch::Tensor& topk_weights, torch::Tensor& topk_indices, torch::Tensor& gating_output, bool renormalize); std::vector moe_fused_gate( at::Tensor& input, diff --git a/sgl-kernel/python/sgl_kernel/moe.py b/sgl-kernel/python/sgl_kernel/moe.py index 34d7518e4..ab7e1702a 100755 --- a/sgl-kernel/python/sgl_kernel/moe.py +++ b/sgl-kernel/python/sgl_kernel/moe.py @@ -30,11 +30,11 @@ def moe_align_block_size( def topk_softmax( topk_weights: torch.Tensor, topk_ids: torch.Tensor, - token_expert_indices: torch.Tensor, gating_output: float, + renormalize: bool = False, ) -> None: torch.ops.sgl_kernel.topk_softmax.default( - topk_weights, topk_ids, token_expert_indices, gating_output + topk_weights, topk_ids, gating_output, renormalize ) diff --git a/sgl-kernel/tests/test_moe_topk_softmax.py b/sgl-kernel/tests/test_moe_topk_softmax.py index 420a3a6d6..b9a802c51 100644 --- a/sgl-kernel/tests/test_moe_topk_softmax.py +++ b/sgl-kernel/tests/test_moe_topk_softmax.py @@ -22,14 +22,10 @@ def test_topk_softmax(num_tokens, num_experts, topk): topk_weights = torch.empty((num_tokens, topk), dtype=torch.float32, device="cuda") topk_indices = torch.empty((num_tokens, topk), dtype=torch.int32, device="cuda") - token_expert_indices = torch.empty( - (num_tokens, topk), dtype=torch.int32, device="cuda" - ) topk_softmax( topk_weights, topk_indices, - token_expert_indices, gating_output, ) @@ -47,5 +43,97 @@ def test_topk_softmax(num_tokens, num_experts, topk): ), f"Indices mismatch: torch={topk_indices_ref}, SGLang={topk_indices}" +@pytest.mark.parametrize( + "num_tokens, num_experts, topk, dtype", + list( + itertools.product( + [1, 16, 128, 512, 1024, 2048], # num_tokens + [4, 8, 16, 32, 64, 128, 256], # num_experts + [1, 2, 4], # topk + [torch.float16, torch.bfloat16, torch.float32], # dtype + ) + ), +) +def test_topk_softmax_dtype_regression(num_tokens, num_experts, topk, dtype): + gating_output = torch.randn((num_tokens, num_experts), dtype=dtype, device="cuda") + + topk_weights = torch.empty((num_tokens, topk), dtype=torch.float32, device="cuda") + topk_indices = torch.empty((num_tokens, topk), dtype=torch.int32, device="cuda") + + topk_softmax( + topk_weights, + topk_indices, + gating_output, + ) + + topk_weights_ref = torch.empty( + (num_tokens, topk), dtype=torch.float32, device="cuda" + ) + topk_indices_ref = torch.empty((num_tokens, topk), dtype=torch.int32, device="cuda") + + topk_softmax( + topk_weights_ref, + topk_indices_ref, + gating_output.float(), + ) + + assert torch.allclose( + topk_weights_ref, topk_weights, atol=1e-3, rtol=1e-3 + ), f"Weights mismatch: SGLang old interface={topk_indices_ref} vs SGLang new interface={topk_weights}" + + assert torch.allclose( + topk_indices_ref.int(), topk_indices, atol=0, rtol=0 + ), f"Indices mismatch: SGLang old interface={topk_indices_ref}, SGLang new interface={topk_indices}" + + +@pytest.mark.parametrize( + "num_tokens, num_experts, topk", + list( + itertools.product( + [1, 16, 128, 512, 1024, 2048], # num_tokens + [4, 8, 16, 32, 64, 128, 256], # num_experts + [1, 2, 4], # topk + ) + ), +) +def test_topk_softmax_renormalize(num_tokens, num_experts, topk): + gating_output = torch.randn( + (num_tokens, num_experts), dtype=torch.bfloat16, device="cuda" + ) + + topk_weights = torch.empty((num_tokens, topk), dtype=torch.float32, device="cuda") + topk_indices = torch.empty((num_tokens, topk), dtype=torch.int32, device="cuda") + + topk_softmax( + topk_weights, + topk_indices, + gating_output, + renormalize=True, + ) + + topk_weights_ref = torch.empty( + (num_tokens, topk), dtype=torch.float32, device="cuda" + ) + topk_indices_ref = torch.empty((num_tokens, topk), dtype=torch.int32, device="cuda") + token_expert_indices_ref = torch.empty( + (num_tokens, topk), dtype=torch.int32, device="cuda" + ) + + topk_softmax( + topk_weights_ref, + topk_indices_ref, + gating_output, + ) + topk_weights_ref = topk_weights_ref / topk_weights_ref.sum(dim=-1, keepdim=True) + + assert torch.allclose( + topk_weights_ref, topk_weights, atol=1e-3, rtol=1e-3 + ), f"Weights mismatch: SGLang w/o fused renormalize={topk_indices_ref} vs SGLang w/ fused renormalize={topk_weights}" + + assert torch.allclose( + topk_indices_ref.int(), topk_indices, atol=0, rtol=0 + ), f"Indices mismatch: SGLang w/o fused renormalize={topk_indices_ref}, SGLang w/ fused renormalize={topk_indices}" + + if __name__ == "__main__": pytest.main([__file__])