diff --git a/sgl-kernel/benchmark/bench_moe_align_block_size.py b/sgl-kernel/benchmark/bench_moe_align_block_size.py index 274502221..30eae0b9a 100644 --- a/sgl-kernel/benchmark/bench_moe_align_block_size.py +++ b/sgl-kernel/benchmark/bench_moe_align_block_size.py @@ -5,7 +5,11 @@ import torch import triton import triton.language as tl from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size -from vllm import _custom_ops as ops + +try: + from vllm import _custom_ops as ops +except ImportError: + ops = None USE_RANDOM_PERM = False @@ -208,7 +212,7 @@ def calculate_diff(num_tokens, num_experts=256, block_size=128, topk=8): ) print(f"✅ VLLM implementation works with {num_experts} experts!") vllm_works = True - except RuntimeError as e: + except Exception as e: print(f"❌ VLLM implementation failed with {num_experts} experts: {e}") vllm_works = False @@ -257,13 +261,47 @@ def get_topk_ids(num_tokens: int, num_experts: int, topk: int) -> torch.Tensor: return topk_ids +def sgl_moe_align_block_size_with_empty( + topk_ids, + num_experts, + block_size, + sorted_ids, + expert_ids, + num_tokens_post_pad, + pad_sorted_token_ids=False, +): + if not pad_sorted_token_ids: + sorted_ids.fill_(topk_ids.numel()) + + token_cnts_buffer = torch.empty( + (num_experts + 1) * num_experts, + dtype=torch.int32, + device=topk_ids.device, + ) + cumsum_buffer = torch.empty( + num_experts + 1, dtype=torch.int32, device=topk_ids.device + ) + + sgl_moe_align_block_size( + topk_ids, + num_experts, + block_size, + sorted_ids.clone(), + expert_ids.clone(), + num_tokens_post_pad.clone(), + token_cnts_buffer, + cumsum_buffer, + pad_sorted_token_ids, + ) + + @triton.testing.perf_report( triton.testing.Benchmark( x_names=["num_tokens", "num_experts", "topk"], x_vals=configs, line_arg="provider", - line_vals=["sgl", "triton", "vllm"], - line_names=["SGL", "Triton", "VLLM"], + line_vals=["sgl", "sgl_fusion", "triton"], + line_names=["SGL", "SGL Fusion", "Triton"], styles=[("blue", "-"), 
("red", "-"), ("green", "-")], ylabel="us", plot_name="moe-align-block-size-performance", @@ -288,7 +326,6 @@ def benchmark(num_tokens, num_experts, topk, provider): sorted_ids = torch.empty( (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device ) - sorted_ids.fill_(topk_ids.numel()) max_num_m_blocks = max_num_tokens_padded // block_size expert_ids = torch.empty( (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device @@ -297,35 +334,6 @@ def benchmark(num_tokens, num_experts, topk, provider): quantiles = [0.5, 0.2, 0.8] if provider == "sgl": - - def sgl_moe_align_block_size_with_empty( - topk_ids, - num_experts, - block_size, - sorted_ids, - expert_ids, - num_tokens_post_pad, - ): - token_cnts_buffer = torch.empty( - (num_experts + 1) * num_experts, - dtype=torch.int32, - device=topk_ids.device, - ) - cumsum_buffer = torch.empty( - num_experts + 1, dtype=torch.int32, device=topk_ids.device - ) - - sgl_moe_align_block_size( - topk_ids, - num_experts, - block_size, - sorted_ids.clone(), - expert_ids.clone(), - num_tokens_post_pad.clone(), - token_cnts_buffer, - cumsum_buffer, - ) - ms, min_ms, max_ms = triton.testing.do_bench( lambda: sgl_moe_align_block_size_with_empty( topk_ids, @@ -337,7 +345,21 @@ def benchmark(num_tokens, num_experts, topk, provider): ), quantiles=quantiles, ) + elif provider == "sgl_fusion": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: sgl_moe_align_block_size_with_empty( + topk_ids, + num_experts, + block_size, + sorted_ids, + expert_ids, + num_tokens_post_pad, + pad_sorted_token_ids=True, + ), + quantiles=quantiles, + ) elif provider == "triton": + sorted_ids.fill_(topk_ids.numel()) ms, min_ms, max_ms = triton.testing.do_bench( lambda: moe_align_block_size_triton( topk_ids, @@ -349,23 +371,6 @@ def benchmark(num_tokens, num_experts, topk, provider): ), quantiles=quantiles, ) - else: # vllm - try: - ms, min_ms, max_ms = triton.testing.do_bench( - lambda: ops.moe_align_block_size( - topk_ids, - num_experts, - 
block_size, - sorted_ids.clone(), - expert_ids.clone(), - num_tokens_post_pad.clone(), - ), - quantiles=quantiles, - ) - except RuntimeError as e: - print(f"❌ VLLM benchmark failed with {num_experts} experts: {e}") - # Return extreme values to indicate failure in the chart - return float("inf"), float("inf"), float("inf") return 1000 * ms, 1000 * max_ms, 1000 * min_ms diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc index 11a9adbb4..324b22fb8 100644 --- a/sgl-kernel/csrc/common_extension.cc +++ b/sgl-kernel/csrc/common_extension.cc @@ -160,7 +160,8 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { */ m.def( "moe_align_block_size(Tensor topk_ids, int num_experts, int block_size, Tensor! sorted_token_ids, Tensor! " - "experts_ids, Tensor! num_tokens_post_pad, Tensor! token_cnts_buffer, Tensor! cumsum_buffer) -> ()"); + "experts_ids, Tensor! num_tokens_post_pad, Tensor! token_cnts_buffer, Tensor! cumsum_buffer, bool " + "pad_sorted_token_ids) -> ()"); m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size); m.def( diff --git a/sgl-kernel/csrc/moe/moe_align_kernel.cu b/sgl-kernel/csrc/moe/moe_align_kernel.cu index e3abb8849..2d0061af8 100644 --- a/sgl-kernel/csrc/moe/moe_align_kernel.cu +++ b/sgl-kernel/csrc/moe/moe_align_kernel.cu @@ -21,8 +21,17 @@ limitations under the License. 
#include "utils.h" +template +class alignas(Alignment) AlignedArray { + public: + T data[N]; +}; + #define WARP_SIZE 32 +#define VEC_SIZE 4 +using Vec = AlignedArray; + template __global__ void count_and_sort_expert_tokens_kernel( const scalar_t* __restrict__ topk_ids, @@ -50,7 +59,8 @@ __global__ void moe_align_block_size_kernel( int32_t experts_per_warp, int32_t block_size, size_t numel, - int32_t* __restrict__ cumsum) { + int32_t* __restrict__ cumsum, + bool pad_sorted_token_ids) { extern __shared__ int32_t shared_counts[]; const int warp_id = threadIdx.x / WARP_SIZE; @@ -96,6 +106,24 @@ __global__ void moe_align_block_size_kernel( expert_ids[i / block_size] = threadIdx.x; } } + + if (pad_sorted_token_ids) { + int32_t fill_val = static_cast(numel); + int32_t total = *total_tokens_post_pad; + + Vec fill_vec; +#pragma unroll + for (int i = 0; i < VEC_SIZE; ++i) { + fill_vec.data[i] = fill_val; + } + + int32_t total_vec_count = (total + VEC_SIZE - 1) / VEC_SIZE; + Vec* out_ptr = reinterpret_cast(sorted_token_ids); + + for (int32_t idx = tid; idx < total_vec_count; idx += stride) { + out_ptr[idx] = fill_vec; + } + } } template @@ -106,7 +134,8 @@ __global__ void moe_align_block_size_small_batch_expert_kernel( int32_t* __restrict__ total_tokens_post_pad, int32_t num_experts, int32_t block_size, - size_t numel) { + size_t numel, + bool pad_sorted_token_ids) { const size_t tid = threadIdx.x; const size_t stride = blockDim.x; @@ -149,6 +178,26 @@ __global__ void moe_align_block_size_small_batch_expert_kernel( } } + if (pad_sorted_token_ids) { + int32_t fill_val = static_cast(numel); + int32_t total = *total_tokens_post_pad; + + Vec fill_vec; +#pragma unroll + for (int i = 0; i < VEC_SIZE; ++i) { + fill_vec.data[i] = fill_val; + } + + int32_t total_vec_count = (total + VEC_SIZE - 1) / VEC_SIZE; + Vec* out_ptr = reinterpret_cast(sorted_token_ids); + + for (int32_t idx = tid; idx < total_vec_count; idx += stride) { + out_ptr[idx] = fill_vec; + } + } + + __syncthreads(); + 
for (size_t i = tid; i < numel; i += stride) { int32_t expert_id = topk_ids[i]; int32_t rank_post_pad = tokens_cnts[threadIdx.x * num_experts + expert_id] + cumsum[expert_id]; @@ -165,7 +214,8 @@ void moe_align_block_size( torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad, torch::Tensor token_cnts_buffer, - torch::Tensor cumsum_buffer) { + torch::Tensor cumsum_buffer, + bool pad_sorted_token_ids) { const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); int64_t padded_num_experts = ((num_experts + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; @@ -190,7 +240,8 @@ void moe_align_block_size( num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size, - topk_ids.numel()); + topk_ids.numel(), + pad_sorted_token_ids); } else { auto align_kernel = moe_align_block_size_kernel<scalar_t>; @@ -207,7 +258,8 @@ void moe_align_block_size( experts_per_warp, block_size, topk_ids.numel(), - cumsum_buffer.data_ptr<int32_t>()); + cumsum_buffer.data_ptr<int32_t>(), + pad_sorted_token_ids); const int block_threads = std::min(256, (int)threads); const int num_blocks = (topk_ids.numel() + block_threads - 1) / block_threads; diff --git a/sgl-kernel/csrc/torch_extension_rocm.cc b/sgl-kernel/csrc/torch_extension_rocm.cc index 5f337c9da..0b1acf685 100644 --- a/sgl-kernel/csrc/torch_extension_rocm.cc +++ b/sgl-kernel/csrc/torch_extension_rocm.cc @@ -59,7 +59,8 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) { */ m.def( "moe_align_block_size(Tensor topk_ids, int num_experts, int block_size, Tensor! sorted_token_ids, Tensor! " - "experts_ids, Tensor! num_tokens_post_pad, Tensor! token_cnts_buffer, Tensor! 
cumsum_buffer, bool " + "pad_sorted_token_ids) -> ()"); m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size); m.def( diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h index c90800f76..1b70bdda6 100644 --- a/sgl-kernel/include/sgl_kernel_ops.h +++ b/sgl-kernel/include/sgl_kernel_ops.h @@ -212,7 +212,8 @@ void moe_align_block_size( torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad, torch::Tensor token_cnts_buffer, - torch::Tensor cumsum_buffer); + torch::Tensor cumsum_buffer, + bool pad_sorted_token_ids); void topk_softmax( torch::Tensor& topk_weights, diff --git a/sgl-kernel/python/sgl_kernel/moe.py b/sgl-kernel/python/sgl_kernel/moe.py index 75fbc6b42..34d7518e4 100755 --- a/sgl-kernel/python/sgl_kernel/moe.py +++ b/sgl-kernel/python/sgl_kernel/moe.py @@ -12,6 +12,7 @@ def moe_align_block_size( num_tokens_post_pad, token_cnts_buffer, cumsum_buffer, + pad_sorted_token_ids=False, ): torch.ops.sgl_kernel.moe_align_block_size.default( topk_ids, @@ -22,6 +23,7 @@ def moe_align_block_size( num_tokens_post_pad, token_cnts_buffer, cumsum_buffer, + pad_sorted_token_ids, ) diff --git a/sgl-kernel/tests/test_moe_align.py b/sgl-kernel/tests/test_moe_align.py index 3baae0a2c..8d35e75c1 100644 --- a/sgl-kernel/tests/test_moe_align.py +++ b/sgl-kernel/tests/test_moe_align.py @@ -138,33 +138,32 @@ def moe_align_block_size_triton( @pytest.mark.parametrize( - "block_size,num_tokens,topk,num_experts", + "block_size,num_tokens,topk,num_experts,pad_sorted_token_ids", list( itertools.product( [32, 64, 128, 256], # block_size [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096], # num_tokens [1, 2, 4, 8, 16, 32, 64], # topk [64, 160, 256, 257, 260, 264], # num_experts + [True, False], # pad_sorted_token_ids ) ), ) def test_moe_align_block_size_compare_implementations( - block_size, num_tokens, topk, num_experts + block_size, num_tokens, topk, num_experts, pad_sorted_token_ids ): - topk_ids = torch.stack( - [ - 
torch.randperm(num_experts, dtype=torch.int32, device="cuda")[:topk] - for _ in range(num_tokens) - ] - ) + topk_ids = torch.argsort(torch.rand(num_tokens, num_experts, device="cuda"), dim=1)[ + :, :topk + ] max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) sorted_ids_cuda = torch.empty( (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device ) - sorted_ids_cuda.fill_(topk_ids.numel()) + if not pad_sorted_token_ids: + sorted_ids_cuda.fill_(topk_ids.numel()) max_num_m_blocks = max_num_tokens_padded // block_size expert_ids_cuda = torch.zeros( (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device @@ -195,6 +194,7 @@ def test_moe_align_block_size_compare_implementations( num_tokens_post_pad_cuda, token_cnts_buffer, cumsum_buffer, + pad_sorted_token_ids, ) moe_align_block_size_triton( @@ -206,20 +206,51 @@ def test_moe_align_block_size_compare_implementations( num_tokens_post_pad_triton, ) - assert torch.allclose(expert_ids_cuda, expert_ids_triton), ( + assert torch.allclose(expert_ids_cuda, expert_ids_triton, atol=0, rtol=0), ( f"Expert IDs mismatch for block_size={block_size}, " f"num_tokens={num_tokens}, topk={topk}\n" f"CUDA expert_ids: {expert_ids_cuda}\n" f"Triton expert_ids: {expert_ids_triton}" ) - assert torch.allclose(num_tokens_post_pad_cuda, num_tokens_post_pad_triton), ( + assert torch.allclose( + num_tokens_post_pad_cuda, num_tokens_post_pad_triton, atol=0, rtol=0 + ), ( f"Num tokens post pad mismatch for block_size={block_size}, " f"num_tokens={num_tokens}, topk={topk}\n" f"CUDA num_tokens_post_pad: {num_tokens_post_pad_cuda}\n" f"Triton num_tokens_post_pad: {num_tokens_post_pad_triton}" ) + # Select an expert to check + expert_idx = expert_ids_cuda.max().item() + + # Get the first and last block id where expert_ids_cuda == expert_idx + matching_indices = torch.where(expert_ids_cuda == expert_idx)[0] + block_sorted_start = matching_indices[0].item() * block_size + block_sorted_end = min( + 
(matching_indices[-1].item() + 1) * block_size, max_num_tokens_padded + ) + + selected_sorted_ids_cuda = sorted_ids_cuda[ + block_sorted_start:block_sorted_end + ].sort()[0] + selected_sorted_ids_triton = sorted_ids_triton[ + block_sorted_start:block_sorted_end + ].sort()[0] + + assert torch.allclose( + selected_sorted_ids_cuda, + selected_sorted_ids_triton, + atol=0, + rtol=0, + ), ( + f"Sorted IDs mismatch for block_size={block_size}, " + f"num_tokens={num_tokens}, topk={topk}\n" + f"CUDA sorted_ids: {selected_sorted_ids_cuda}\n" + f"Triton sorted_ids: {selected_sorted_ids_triton}" + ) + if __name__ == "__main__": pytest.main([__file__])