diff --git a/sgl-kernel/benchmark/bench_moe_align_block_size.py b/sgl-kernel/benchmark/bench_moe_align_block_size.py
index 30eae0b9a..ed8a7b8f3 100644
--- a/sgl-kernel/benchmark/bench_moe_align_block_size.py
+++ b/sgl-kernel/benchmark/bench_moe_align_block_size.py
@@ -164,9 +164,6 @@ def calculate_diff(num_tokens, num_experts=256, block_size=128, topk=8):
     num_tokens_post_pad_cuda = torch.empty(
         (1), dtype=torch.int32, device=topk_ids.device
     )
-    token_cnts_buffer = torch.zeros(
-        (num_experts + 1) * num_experts, dtype=torch.int32, device=topk_ids.device
-    )
     cumsum_buffer = torch.zeros(
         num_experts + 1, dtype=torch.int32, device=topk_ids.device
     )
@@ -189,7 +186,6 @@ def calculate_diff(num_tokens, num_experts=256, block_size=128, topk=8):
         sorted_ids_cuda,
         expert_ids_cuda,
         num_tokens_post_pad_cuda,
-        token_cnts_buffer,
         cumsum_buffer,
     )
     moe_align_block_size_triton(
@@ -273,11 +269,6 @@ def sgl_moe_align_block_size_with_empty(
     if not pad_sorted_token_ids:
         sorted_ids.fill_(topk_ids.numel())
 
-    token_cnts_buffer = torch.empty(
-        (num_experts + 1) * num_experts,
-        dtype=torch.int32,
-        device=topk_ids.device,
-    )
     cumsum_buffer = torch.empty(
         num_experts + 1, dtype=torch.int32, device=topk_ids.device
     )
@@ -289,7 +280,6 @@ def sgl_moe_align_block_size_with_empty(
         sorted_ids.clone(),
         expert_ids.clone(),
         num_tokens_post_pad.clone(),
-        token_cnts_buffer,
         cumsum_buffer,
         pad_sorted_token_ids,
     )
diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc
index 623fbefb5..295939900 100644
--- a/sgl-kernel/csrc/common_extension.cc
+++ b/sgl-kernel/csrc/common_extension.cc
@@ -165,7 +165,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   */
   m.def(
       "moe_align_block_size(Tensor topk_ids, int num_experts, int block_size, Tensor! sorted_token_ids, Tensor! "
-      "experts_ids, Tensor! num_tokens_post_pad, Tensor! token_cnts_buffer, Tensor! cumsum_buffer, bool "
+      "experts_ids, Tensor! num_tokens_post_pad, Tensor! cumsum_buffer, bool "
       "pad_sorted_token_ids) -> ()");
   m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
 
diff --git a/sgl-kernel/csrc/moe/moe_align_kernel.cu b/sgl-kernel/csrc/moe/moe_align_kernel.cu
index ea17b329c..19d0cc7a9 100644
--- a/sgl-kernel/csrc/moe/moe_align_kernel.cu
+++ b/sgl-kernel/csrc/moe/moe_align_kernel.cu
@@ -36,7 +36,7 @@ __global__ void count_and_sort_expert_tokens_kernel(
   const size_t stride = blockDim.x * gridDim.x;
 
   for (size_t i = tid; i < numel; i += stride) {
-    int32_t expert_id = topk_ids[i];
+    int32_t expert_id = topk_ids[i] + 1;
     int32_t rank_post_pad = atomicAdd(&cumsum_buffer[expert_id], 1);
     sorted_token_ids[rank_post_pad] = i;
   }
@@ -82,7 +82,7 @@ __global__ void moe_align_block_size_kernel(
   __syncthreads();
 
   for (size_t i = tid; i < numel; i += stride) {
-    int expert_id = topk_ids[i];
+    int expert_id = topk_ids[i] + 1;
     atomicAdd(&shared_counts[expert_id], 1);
   }
 
@@ -215,7 +215,7 @@ __global__ void moe_align_block_size_kernel(
         right = mid;
       }
     }
-    expert_ids[i] = left - 1;
+    expert_ids[i] = left - 2;
   }
 
   if (pad_sorted_token_ids) {
@@ -251,7 +251,7 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
   }
 
   for (size_t i = tid; i < numel; i += stride) {
-    ++tokens_cnts[(threadIdx.x + 1) * num_experts + topk_ids[i]];
+    ++tokens_cnts[(threadIdx.x + 1) * num_experts + topk_ids[i] + 1];
   }
 
   __syncthreads();
@@ -277,7 +277,7 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
 
   if (threadIdx.x < num_experts) {
     for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; i += block_size) {
-      expert_ids[i / block_size] = threadIdx.x;
+      expert_ids[i / block_size] = threadIdx.x - 1;
     }
   }
 
@@ -294,7 +294,7 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
   __syncthreads();
 
   for (size_t i = tid; i < numel; i += stride) {
-    int32_t expert_id = topk_ids[i];
+    int32_t expert_id = topk_ids[i] + 1;
     int32_t rank_post_pad = tokens_cnts[threadIdx.x * num_experts + expert_id] + cumsum[expert_id];
     sorted_token_ids[rank_post_pad] = i;
     ++tokens_cnts[threadIdx.x * num_experts + expert_id];
@@ -308,7 +308,6 @@ void moe_align_block_size(
    torch::Tensor sorted_token_ids,
    torch::Tensor experts_ids,
    torch::Tensor num_tokens_post_pad,
-    torch::Tensor token_cnts_buffer,
    torch::Tensor cumsum_buffer,
    bool pad_sorted_token_ids) {
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
diff --git a/sgl-kernel/csrc/torch_extension_rocm.cc b/sgl-kernel/csrc/torch_extension_rocm.cc
index 9010d0b26..aaf474fb2 100644
--- a/sgl-kernel/csrc/torch_extension_rocm.cc
+++ b/sgl-kernel/csrc/torch_extension_rocm.cc
@@ -92,7 +92,7 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
   */
   m.def(
       "moe_align_block_size(Tensor topk_ids, int num_experts, int block_size, Tensor! sorted_token_ids, Tensor! "
-      "experts_ids, Tensor! num_tokens_post_pad, Tensor! token_cnts_buffer, Tensor! cumsum_buffer, bool "
+      "experts_ids, Tensor! num_tokens_post_pad, Tensor! cumsum_buffer, bool "
       "pad_sorted_token_ids) -> ()");
   m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
 
diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h
index ca8276050..fa6de7362 100644
--- a/sgl-kernel/include/sgl_kernel_ops.h
+++ b/sgl-kernel/include/sgl_kernel_ops.h
@@ -230,7 +230,6 @@ void moe_align_block_size(
    torch::Tensor sorted_token_ids,
    torch::Tensor experts_ids,
    torch::Tensor num_tokens_post_pad,
-    torch::Tensor token_cnts_buffer,
    torch::Tensor cumsum_buffer,
    bool pad_sorted_token_ids);
 
diff --git a/sgl-kernel/python/sgl_kernel/moe.py b/sgl-kernel/python/sgl_kernel/moe.py
index ab7e1702a..c16a2b6fe 100755
--- a/sgl-kernel/python/sgl_kernel/moe.py
+++ b/sgl-kernel/python/sgl_kernel/moe.py
@@ -10,7 +10,6 @@ def moe_align_block_size(
     sorted_token_ids,
     experts_ids,
     num_tokens_post_pad,
-    token_cnts_buffer,
     cumsum_buffer,
     pad_sorted_token_ids=False,
 ):
@@ -21,7 +20,6 @@ def moe_align_block_size(
         sorted_token_ids,
         experts_ids,
         num_tokens_post_pad,
-        token_cnts_buffer,
         cumsum_buffer,
         pad_sorted_token_ids,
     )
diff --git a/sgl-kernel/tests/test_moe_align.py b/sgl-kernel/tests/test_moe_align.py
index 550c7a1ad..90f04ec95 100644
--- a/sgl-kernel/tests/test_moe_align.py
+++ b/sgl-kernel/tests/test_moe_align.py
@@ -157,7 +157,7 @@ def test_moe_align_block_size_compare_implementations(
         :, :topk
     ]
 
-    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
+    max_num_tokens_padded = topk_ids.numel() + (num_experts + 1) * (block_size - 1)
 
     sorted_ids_cuda = torch.empty(
         (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
@@ -171,13 +171,8 @@ def test_moe_align_block_size_compare_implementations(
     num_tokens_post_pad_cuda = torch.empty(
         (1), dtype=torch.int32, device=topk_ids.device
     )
-    token_cnts_buffer = torch.empty(
-        (num_experts + 1) * num_experts,
-        dtype=torch.int32,
-        device=topk_ids.device,
-    )
     cumsum_buffer = torch.empty(
-        num_experts + 1, dtype=torch.int32, device=topk_ids.device
+        num_experts + 2, dtype=torch.int32, device=topk_ids.device
     )
 
     sorted_ids_triton = torch.empty_like(sorted_ids_cuda)
@@ -187,19 +182,18 @@ def test_moe_align_block_size_compare_implementations(
 
     moe_align_block_size(
         topk_ids,
-        num_experts,
+        num_experts + 1,
         block_size,
         sorted_ids_cuda,
         expert_ids_cuda,
         num_tokens_post_pad_cuda,
-        token_cnts_buffer,
         cumsum_buffer,
         pad_sorted_token_ids,
     )
 
     moe_align_block_size_triton(
         topk_ids,
-        num_experts,
+        num_experts + 1,
         block_size,
         sorted_ids_triton,
         expert_ids_triton,
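
Note: after this patch, callers no longer allocate or pass token_cnts_buffer. A minimal call sketch follows, mirroring the buffer setup in calculate_diff above; the sgl_kernel import path, the concrete sizes, and the expert_ids rounding are illustrative assumptions, not part of the patch.

    # Sketch of the post-patch call, assuming moe_align_block_size is exported
    # from sgl_kernel as wrapped in sgl-kernel/python/sgl_kernel/moe.py.
    import torch

    from sgl_kernel import moe_align_block_size  # assumed import path

    num_tokens, num_experts, block_size, topk = 1024, 256, 128, 8
    topk_ids = torch.randint(
        0, num_experts, (num_tokens, topk), dtype=torch.int32, device="cuda"
    )

    # Worst case: every expert's token count is padded up to a block boundary.
    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
    sorted_ids = torch.empty(max_num_tokens_padded, dtype=torch.int32, device="cuda")
    expert_ids = torch.empty(
        (max_num_tokens_padded + block_size - 1) // block_size,  # assumed ceil-div sizing
        dtype=torch.int32,
        device="cuda",
    )
    num_tokens_post_pad = torch.empty(1, dtype=torch.int32, device="cuda")
    # The (num_experts + 1) * num_experts token_cnts_buffer is gone; only the
    # per-expert cumulative-sum buffer remains.
    cumsum_buffer = torch.zeros(num_experts + 1, dtype=torch.int32, device="cuda")

    moe_align_block_size(
        topk_ids,
        num_experts,
        block_size,
        sorted_ids,
        expert_ids,
        num_tokens_post_pad,
        cumsum_buffer,
        pad_sorted_token_ids=False,
    )

The updated test instead passes num_experts + 1 and a cumsum_buffer of num_experts + 2 entries, consistent with the kernel's new one-slot shift of expert ids (topk_ids[i] + 1) visible in moe_align_kernel.cu.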