From 817d43705cd7f54aa5256e29edcd865c66250a5a Mon Sep 17 00:00:00 2001
From: Shi Shuai <126407087+shuaills@users.noreply.github.com>
Date: Thu, 13 Mar 2025 11:50:46 +0800
Subject: [PATCH] feat: support ep size < 32 for sgl kernel (#4348)

---
 .../benchmark/bench_moe_align_block_size.py | 83 ++++++++++++-------
 sgl-kernel/csrc/moe/moe_align_kernel.cu     | 34 ++++++--
 2 files changed, 82 insertions(+), 35 deletions(-)

diff --git a/sgl-kernel/benchmark/bench_moe_align_block_size.py b/sgl-kernel/benchmark/bench_moe_align_block_size.py
index 1fb0bd342..4266acfb5 100644
--- a/sgl-kernel/benchmark/bench_moe_align_block_size.py
+++ b/sgl-kernel/benchmark/bench_moe_align_block_size.py
@@ -196,14 +196,21 @@ def calculate_diff(num_tokens, num_experts=256, block_size=128, topk=8):
         expert_ids_triton,
         num_tokens_post_pad_triton,
     )
-    ops.moe_align_block_size(
-        topk_ids,
-        num_experts,
-        block_size,
-        sorted_ids_vllm,
-        expert_ids_vllm,
-        num_tokens_post_pad_vllm,
-    )
+
+    try:
+        ops.moe_align_block_size(
+            topk_ids,
+            num_experts,
+            block_size,
+            sorted_ids_vllm,
+            expert_ids_vllm,
+            num_tokens_post_pad_vllm,
+        )
+        print(f"✅ VLLM implementation works with {num_experts} experts!")
+        vllm_works = True
+    except RuntimeError as e:
+        print(f"❌ VLLM implementation failed with {num_experts} experts: {e}")
+        vllm_works = False
 
     if torch.allclose(expert_ids_cuda, expert_ids_triton) and torch.allclose(
         num_tokens_post_pad_cuda, num_tokens_post_pad_triton
@@ -216,20 +223,26 @@ def calculate_diff(num_tokens, num_experts=256, block_size=128, topk=8):
         print("SGL num_tokens_post_pad:", num_tokens_post_pad_cuda)
         print("Triton num_tokens_post_pad:", num_tokens_post_pad_triton)
 
-    if torch.allclose(expert_ids_cuda, expert_ids_vllm) and torch.allclose(
-        num_tokens_post_pad_cuda, num_tokens_post_pad_vllm
+    if (
+        vllm_works
+        and torch.allclose(expert_ids_cuda, expert_ids_vllm)
+        and torch.allclose(num_tokens_post_pad_cuda, num_tokens_post_pad_vllm)
     ):
         print("✅ SGL and VLLM implementations match")
     else:
-        print("❌ SGL and VLLM implementations do not match")
-        print("SGL expert_ids:", expert_ids_cuda)
-        print("VLLM expert_ids:", expert_ids_vllm)
-        print("SGL num_tokens_post_pad:", num_tokens_post_pad_cuda)
-        print("VLLM num_tokens_post_pad:", num_tokens_post_pad_vllm)
+        if not vllm_works:
+            print("⚠️ VLLM comparison skipped due to failure")
+        else:
+            print("❌ SGL and VLLM implementations do not match")
+            print("SGL expert_ids:", expert_ids_cuda)
+            print("VLLM expert_ids:", expert_ids_vllm)
+            print("SGL num_tokens_post_pad:", num_tokens_post_pad_cuda)
+            print("VLLM num_tokens_post_pad:", num_tokens_post_pad_vllm)
 
 
+# Test range
 num_tokens_range = [16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
-num_experts_range = [32, 64, 128, 256]
+num_experts_range = [8, 32, 64, 128, 256]
 topk_range = [2, 4, 8]
 
 configs = list(itertools.product(num_tokens_range, num_experts_range, topk_range))
@@ -316,17 +329,22 @@ def benchmark(num_tokens, num_experts, topk, provider):
             quantiles=quantiles,
         )
     else:  # vllm
-        ms, min_ms, max_ms = triton.testing.do_bench(
-            lambda: ops.moe_align_block_size(
-                topk_ids,
-                num_experts,
-                block_size,
-                sorted_ids.clone(),
-                expert_ids.clone(),
-                num_tokens_post_pad.clone(),
-            ),
-            quantiles=quantiles,
-        )
+        try:
+            ms, min_ms, max_ms = triton.testing.do_bench(
+                lambda: ops.moe_align_block_size(
+                    topk_ids,
+                    num_experts,
+                    block_size,
+                    sorted_ids.clone(),
+                    expert_ids.clone(),
+                    num_tokens_post_pad.clone(),
+                ),
+                quantiles=quantiles,
+            )
+        except RuntimeError as e:
+            print(f"❌ VLLM benchmark failed with {num_experts} experts: {e}")
+            # Return extreme values to indicate failure in the chart
+            return float("inf"), float("inf"), float("inf")
 
     return 1000 * ms, 1000 * max_ms, 1000 * min_ms
 
@@ -343,7 +361,7 @@ if __name__ == "__main__":
         "--num_experts",
         type=int,
         default=256,
-        choices=[8, 64, 128, 256],
+        choices=[8, 16, 32, 64, 128, 256],
         help="Number of experts for benchmark",
     )
     parser.add_argument(
@@ -353,8 +371,15 @@ if __name__ == "__main__":
         choices=[2, 4, 8],
         help="Top-k value for benchmark",
     )
+    parser.add_argument(
+        "--skip_full_benchmark",
+        action="store_true",
+        help="Only run the calculate_diff function, skip full benchmarking",
+    )
     args = parser.parse_args()
 
     calculate_diff(num_tokens=1024, num_experts=args.num_experts, topk=args.topk)
-    benchmark.run(print_data=True)
+    if not args.skip_full_benchmark:
+        print(f"\n📊 Running performance benchmark for {args.num_experts} experts...")
+        benchmark.run(print_data=True)
 
diff --git a/sgl-kernel/csrc/moe/moe_align_kernel.cu b/sgl-kernel/csrc/moe/moe_align_kernel.cu
index 83609a329..cfb7adca5 100644
--- a/sgl-kernel/csrc/moe/moe_align_kernel.cu
+++ b/sgl-kernel/csrc/moe/moe_align_kernel.cu
@@ -47,6 +47,7 @@ __global__ void moe_align_block_size_kernel(
     int32_t* __restrict__ expert_ids,
     int32_t* __restrict__ total_tokens_post_pad,
     int32_t num_experts,
+    int32_t padded_num_experts,
     int32_t experts_per_warp,
     int32_t block_size,
     size_t numel,
@@ -57,7 +58,7 @@
   const int my_expert_start = warp_id * experts_per_warp;
 
   for (int i = 0; i < experts_per_warp; ++i) {
-    if (my_expert_start + i < num_experts) {
+    if (my_expert_start + i < padded_num_experts) {
       shared_counts[warp_id * experts_per_warp + i] = 0;
     }
   }
@@ -108,23 +109,44 @@ void moe_align_block_size(
     torch::Tensor token_cnts_buffer,
     torch::Tensor cumsum_buffer) {
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  TORCH_CHECK(num_experts % WARP_SIZE == 0);
-  int experts_per_warp = num_experts / WARP_SIZE;
+
+  int64_t padded_num_experts = ((num_experts + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
+
+  int experts_per_warp;
+  int threads;
+
+  if (num_experts <= 8) {
+    experts_per_warp = 8;
+    threads = 256;
+  } else if (num_experts <= 16) {
+    experts_per_warp = 16;
+    threads = 512;
+  } else {
+    experts_per_warp = WARP_SIZE;
+    threads = 1024;
+  }
+
+  threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
+
   DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
     auto align_kernel = moe_align_block_size_kernel<scalar_t>;
-    size_t shared_mem_size = 32 * experts_per_warp * sizeof(int32_t);
-    align_kernel<<<1, 1024, shared_mem_size, stream>>>(
+
+    size_t num_warps = CEILDIV(padded_num_experts, experts_per_warp);
+    size_t shared_mem_size = num_warps * experts_per_warp * sizeof(int32_t);
+
+    align_kernel<<<1, threads, shared_mem_size, stream>>>(
        topk_ids.data_ptr<scalar_t>(),
        sorted_token_ids.data_ptr<int32_t>(),
        experts_ids.data_ptr<int32_t>(),
        num_tokens_post_pad.data_ptr<int32_t>(),
        num_experts,
+       padded_num_experts,
        experts_per_warp,
        block_size,
        topk_ids.numel(),
        cumsum_buffer.data_ptr<int32_t>());
 
-    const int block_threads = 256;
+    const int block_threads = std::min(256, (int)threads);
     const int num_blocks = (topk_ids.numel() + block_threads - 1) / block_threads;
     const int max_blocks = 65535;
     const int actual_blocks = std::min(num_blocks, max_blocks);