diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
index 8418334b9..c90be69e6 100644
--- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
@@ -702,7 +702,7 @@ def moe_align_block_size(
             num_tokens_post_pad,
         )
     else:
-        token_cnts_buffer = torch.zeros(
+        token_cnts_buffer = torch.empty(
             (num_experts + 1) * num_experts,
             dtype=torch.int32,
             device=topk_ids.device,
diff --git a/sgl-kernel/benchmark/bench_moe_align_block_size.py b/sgl-kernel/benchmark/bench_moe_align_block_size.py
index 4266acfb5..274502221 100644
--- a/sgl-kernel/benchmark/bench_moe_align_block_size.py
+++ b/sgl-kernel/benchmark/bench_moe_align_block_size.py
@@ -241,9 +241,9 @@ def calculate_diff(num_tokens, num_experts=256, block_size=128, topk=8):
 
 
 # Test range
-num_tokens_range = [16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
+num_tokens_range = [1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
 num_experts_range = [8, 32, 64, 128, 256]
-topk_range = [2, 4, 8]
+topk_range = [1, 2, 4, 8]
 
 configs = list(itertools.product(num_tokens_range, num_experts_range, topk_range))
 
@@ -294,17 +294,28 @@ def benchmark(num_tokens, num_experts, topk, provider):
         (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
     )
     num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
-    token_cnts_buffer = torch.zeros(
-        (num_experts + 1) * num_experts, dtype=torch.int32, device=topk_ids.device
-    )
-    cumsum_buffer = torch.zeros(
-        num_experts + 1, dtype=torch.int32, device=topk_ids.device
-    )
 
     quantiles = [0.5, 0.2, 0.8]
 
     if provider == "sgl":
-        ms, min_ms, max_ms = triton.testing.do_bench(
-            lambda: sgl_moe_align_block_size(
+
+        def sgl_moe_align_block_size_with_empty(
+            topk_ids,
+            num_experts,
+            block_size,
+            sorted_ids,
+            expert_ids,
+            num_tokens_post_pad,
+        ):
+            token_cnts_buffer = torch.empty(
+                (num_experts + 1) * num_experts,
+                dtype=torch.int32,
+                device=topk_ids.device,
+            )
+            cumsum_buffer = torch.empty(
+                num_experts + 1, dtype=torch.int32, device=topk_ids.device
+            )
+
+            sgl_moe_align_block_size(
                 topk_ids,
                 num_experts,
                 block_size,
@@ -313,6 +324,16 @@ def benchmark(num_tokens, num_experts, topk, provider):
                 num_tokens_post_pad.clone(),
                 token_cnts_buffer,
                 cumsum_buffer,
+            )
+
+        ms, min_ms, max_ms = triton.testing.do_bench(
+            lambda: sgl_moe_align_block_size_with_empty(
+                topk_ids,
+                num_experts,
+                block_size,
+                sorted_ids,
+                expert_ids,
+                num_tokens_post_pad,
             ),
             quantiles=quantiles,
         )
diff --git a/sgl-kernel/csrc/moe/moe_align_kernel.cu b/sgl-kernel/csrc/moe/moe_align_kernel.cu
index 6ffa73924..d44eff5c1 100644
--- a/sgl-kernel/csrc/moe/moe_align_kernel.cu
+++ b/sgl-kernel/csrc/moe/moe_align_kernel.cu
@@ -64,10 +64,10 @@ __global__ void moe_align_block_size_kernel(
 
   __syncthreads();
 
-  const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
-  const size_t start_idx = threadIdx.x * tokens_per_thread;
+  const size_t tid = threadIdx.x;
+  const size_t stride = blockDim.x;
 
-  for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
+  for (size_t i = tid; i < numel; i += stride) {
     int expert_id = topk_ids[i];
     int warp_idx = expert_id / experts_per_warp;
     int expert_offset = expert_id % experts_per_warp;
@@ -98,6 +98,65 @@ __global__ void moe_align_block_size_kernel(
   }
 }
 
+template <typename scalar_t>
+__global__ void moe_align_block_size_small_batch_expert_kernel(
+    const scalar_t* __restrict__ topk_ids,
+    int32_t* __restrict__ sorted_token_ids,
+    int32_t* __restrict__ expert_ids,
+    int32_t* __restrict__ total_tokens_post_pad,
+    int32_t num_experts,
+    int32_t block_size,
+    size_t numel) {
+  const size_t tid = threadIdx.x;
+  const size_t stride = blockDim.x;
+
+  extern __shared__ int32_t shared_mem[];
+  int32_t* cumsum = shared_mem;
+  int32_t* tokens_cnts = (int32_t*)(shared_mem + num_experts + 1);
+
+  for (int i = 0; i < num_experts; ++i) {
+    tokens_cnts[(threadIdx.x + 1) * num_experts + i] = 0;
+  }
+
+  for (size_t i = tid; i < numel; i += stride) {
+    ++tokens_cnts[(threadIdx.x + 1) * num_experts + topk_ids[i]];
+  }
+
+  __syncthreads();
+
+  if (threadIdx.x < num_experts) {
+    tokens_cnts[threadIdx.x] = 0;
+    for (int i = 1; i <= blockDim.x; ++i) {
+      tokens_cnts[i * num_experts + threadIdx.x] += tokens_cnts[(i - 1) * num_experts + threadIdx.x];
+    }
+  }
+
+  __syncthreads();
+
+  if (threadIdx.x == 0) {
+    cumsum[0] = 0;
+    for (int i = 1; i <= num_experts; ++i) {
+      cumsum[i] = cumsum[i - 1] + CEILDIV(tokens_cnts[blockDim.x * num_experts + i - 1], block_size) * block_size;
+    }
+    *total_tokens_post_pad = static_cast<int32_t>(cumsum[num_experts]);
+  }
+
+  __syncthreads();
+
+  if (threadIdx.x < num_experts) {
+    for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; i += block_size) {
+      expert_ids[i / block_size] = threadIdx.x;
+    }
+  }
+
+  for (size_t i = tid; i < numel; i += stride) {
+    int32_t expert_id = topk_ids[i];
+    int32_t rank_post_pad = tokens_cnts[threadIdx.x * num_experts + expert_id] + cumsum[expert_id];
+    sorted_token_ids[rank_post_pad] = i;
+    ++tokens_cnts[threadIdx.x * num_experts + expert_id];
+  }
+}
+
 void moe_align_block_size(
     torch::Tensor topk_ids,
     int64_t num_experts,
@@ -111,50 +170,58 @@ void moe_align_block_size(
   int64_t padded_num_experts = ((num_experts + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
 
-  int experts_per_warp;
-  int threads;
-
-  if (num_experts <= 8) {
-    experts_per_warp = 8;
-    threads = 256;
-  } else if (num_experts <= 16) {
-    experts_per_warp = 16;
-    threads = 512;
-  } else {
-    experts_per_warp = WARP_SIZE;
-    threads = 1024;
-  }
+  int experts_per_warp = WARP_SIZE;
+  int threads = 1024;
 
   threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
 
   DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
-    auto align_kernel = moe_align_block_size_kernel<scalar_t>;
+    bool small_batch_expert_mode = (topk_ids.numel() < 1024) && (num_experts <= 64);
 
-    size_t num_warps = CEILDIV(padded_num_experts, experts_per_warp);
-    size_t shared_mem_size = num_warps * experts_per_warp * sizeof(int32_t);
+    if (small_batch_expert_mode) {
+      const int32_t threads = max((int32_t)num_experts, WARP_SIZE);
+      const int32_t shared_mem_size = ((threads + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t);
 
-    align_kernel<<<1, threads, shared_mem_size, stream>>>(
-        topk_ids.data_ptr<scalar_t>(),
-        sorted_token_ids.data_ptr<int32_t>(),
-        experts_ids.data_ptr<int32_t>(),
-        num_tokens_post_pad.data_ptr<int32_t>(),
-        num_experts,
-        padded_num_experts,
-        experts_per_warp,
-        block_size,
-        topk_ids.numel(),
-        cumsum_buffer.data_ptr<int32_t>());
+      auto small_batch_expert_kernel = moe_align_block_size_small_batch_expert_kernel<scalar_t>;
+      small_batch_expert_kernel<<<1, threads, shared_mem_size, stream>>>(
+          topk_ids.data_ptr<scalar_t>(),
+          sorted_token_ids.data_ptr<int32_t>(),
+          experts_ids.data_ptr<int32_t>(),
+          num_tokens_post_pad.data_ptr<int32_t>(),
+          num_experts,
+          block_size,
+          topk_ids.numel());
+    } else {
+      auto align_kernel = moe_align_block_size_kernel<scalar_t>;
 
-    const int block_threads = std::min(256, (int)threads);
-    const int num_blocks = (topk_ids.numel() + block_threads - 1) / block_threads;
-    const int max_blocks = 65535;
-    const int actual_blocks = std::min(num_blocks, max_blocks);
+      size_t num_warps = CEILDIV(padded_num_experts, experts_per_warp);
+      size_t shared_mem_size = num_warps * experts_per_warp * sizeof(int32_t);
 
-    auto sort_kernel = count_and_sort_expert_tokens_kernel<scalar_t>;
-    sort_kernel<<<actual_blocks, block_threads, 0, stream>>>(
-        topk_ids.data_ptr<scalar_t>(),
-        sorted_token_ids.data_ptr<int32_t>(),
-        cumsum_buffer.data_ptr<int32_t>(),
-        topk_ids.numel());
+      cumsum_buffer.zero_();
+
+      align_kernel<<<1, threads, shared_mem_size, stream>>>(
+          topk_ids.data_ptr<scalar_t>(),
+          sorted_token_ids.data_ptr<int32_t>(),
+          experts_ids.data_ptr<int32_t>(),
+          num_tokens_post_pad.data_ptr<int32_t>(),
+          num_experts,
+          padded_num_experts,
+          experts_per_warp,
+          block_size,
+          topk_ids.numel(),
+          cumsum_buffer.data_ptr<int32_t>());
+
+      const int block_threads = std::min(256, (int)threads);
+      const int num_blocks = (topk_ids.numel() + block_threads - 1) / block_threads;
+      const int max_blocks = 65535;
+      const int actual_blocks = std::min(num_blocks, max_blocks);
+
+      auto sort_kernel = count_and_sort_expert_tokens_kernel<scalar_t>;
+      sort_kernel<<<actual_blocks, block_threads, 0, stream>>>(
+          topk_ids.data_ptr<scalar_t>(),
+          sorted_token_ids.data_ptr<int32_t>(),
+          cumsum_buffer.data_ptr<int32_t>(),
+          topk_ids.numel());
+    }
   });
 }
diff --git a/sgl-kernel/tests/test_moe_align.py b/sgl-kernel/tests/test_moe_align.py
index fb7c4c640..3baae0a2c 100644
--- a/sgl-kernel/tests/test_moe_align.py
+++ b/sgl-kernel/tests/test_moe_align.py
@@ -151,7 +151,6 @@ def moe_align_block_size_triton(
 def test_moe_align_block_size_compare_implementations(
     block_size, num_tokens, topk, num_experts
 ):
-    # For DeepSeek V3, we have 256 experts
     topk_ids = torch.stack(
         [