From ad3499858e34da72396087b496e6fee6c0de2781 Mon Sep 17 00:00:00 2001 From: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com> Date: Thu, 6 Feb 2025 16:42:36 +0800 Subject: [PATCH] clean moe align block kernel code and add acc test (#3332) --- .../benchmark_deepseekv3_moe_align_blocks.py | 2 +- .../src/sgl-kernel/csrc/moe_align_kernel.cu | 40 +-- sgl-kernel/src/sgl-kernel/include/utils.h | 12 + sgl-kernel/tests/test_moe_align.py | 263 ++++++++++++++---- 4 files changed, 237 insertions(+), 80 deletions(-) diff --git a/benchmark/kernels/fused_moe_triton/benchmark_deepseekv3_moe_align_blocks.py b/benchmark/kernels/fused_moe_triton/benchmark_deepseekv3_moe_align_blocks.py index e2c4d8d35..07b3acaaf 100644 --- a/benchmark/kernels/fused_moe_triton/benchmark_deepseekv3_moe_align_blocks.py +++ b/benchmark/kernels/fused_moe_triton/benchmark_deepseekv3_moe_align_blocks.py @@ -310,4 +310,4 @@ if __name__ == "__main__": calculate_diff(batch_size=4, seq_len=1024) - benchmark.run(print_data=True, save_path=args.save_path) + benchmark.run(print_data=True) diff --git a/sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu index d51ca5175..415fea8cc 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu @@ -13,8 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// Adapted from https://github.com/vllm-project/vllm/blob/v0.6.5/csrc/moe/moe_align_sum_kernels.cu - #include #include #include @@ -22,32 +20,15 @@ limitations under the License. #include +#include "utils.h" + #define WARP_SIZE 32 -#define DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \ - cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL) - -#define CEILDIV(x, y) (((x) + (y)-1) / (y)) - -#define DISPATCH_CASE_INTEGRAL_TYPES(...) \ - AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) - -#define DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \ - AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__)) - -__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row, int32_t col) { - return row * total_col + col; -} - template __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids, int32_t* expert_ids, int32_t* total_tokens_post_pad, int32_t num_experts, int32_t block_size, size_t numel, int32_t* cumsum) { - __shared__ int32_t shared_counts[32][8]; + __shared__ int32_t shared_counts[WARP_SIZE][8]; __shared__ int32_t local_offsets[256]; const int warp_id = threadIdx.x / WARP_SIZE; @@ -96,6 +77,11 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, int __syncthreads(); + // Note: For the moe_align_kernel, the primary bottleneck lies in the atomic add and non-coalesced memory writes here. + // If these operations can be performed using multiple blocks, similar to the Triton version, the performance of this + // kernel can achieve state-of-the-art performance across all token cases. However, once multiple blocks are used, + // illegal memory access occurs. Even replacing these lines of code with the stage 4 kernel from the Triton version + // results in the same issue, and a correct solution has not yet been found. for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { int32_t expert_id = topk_ids[i]; int32_t rank_post_pad = atomicAdd(&local_offsets[expert_id], 1); @@ -107,10 +93,12 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t b torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad, torch::Tensor token_cnts_buffer, torch::Tensor cumsum_buffer) { const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + TORCH_CHECK(num_experts == 256, "moe_align_block_size kernel only support deepseek v3 now."); + DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { - auto kernel = moe_align_block_size_kernel; - kernel<<<1, 1024, 0, stream>>>(topk_ids.data_ptr(), sorted_token_ids.data_ptr(), - experts_ids.data_ptr(), num_tokens_post_pad.data_ptr(), - num_experts, block_size, topk_ids.numel(), cumsum_buffer.data_ptr()); + auto align_kernel = moe_align_block_size_kernel; + align_kernel<<<1, 1024, 0, stream>>>(topk_ids.data_ptr(), sorted_token_ids.data_ptr(), + experts_ids.data_ptr(), num_tokens_post_pad.data_ptr(), + num_experts, block_size, topk_ids.numel(), cumsum_buffer.data_ptr()); }); } diff --git a/sgl-kernel/src/sgl-kernel/include/utils.h b/sgl-kernel/src/sgl-kernel/include/utils.h index b714df775..129402112 100644 --- a/sgl-kernel/src/sgl-kernel/include/utils.h +++ b/sgl-kernel/src/sgl-kernel/include/utils.h @@ -79,3 +79,15 @@ inline int getSMVersion() { return false; \ } \ }() + +#define DISPATCH_CASE_INTEGRAL_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) + +#define DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__)) + +#define CEILDIV(x, y) (((x) + (y)-1) / (y)) diff --git a/sgl-kernel/tests/test_moe_align.py b/sgl-kernel/tests/test_moe_align.py index 2fca90b2f..81d05ffa1 100644 --- a/sgl-kernel/tests/test_moe_align.py +++ b/sgl-kernel/tests/test_moe_align.py @@ -1,67 +1,224 @@ +import itertools + +import pytest import torch +import triton +import triton.language as tl from sgl_kernel import moe_align_block_size -def test_moe_align_block_size(): +def ceil_div(a, b): + return (a + b - 1) // b + + +@triton.jit +def moe_align_block_size_stage1( + topk_ids_ptr, + tokens_cnts_ptr, + num_experts: tl.constexpr, + numel: tl.constexpr, + tokens_per_thread: tl.constexpr, +): + pid = tl.program_id(0) + start_idx = pid * tokens_per_thread + off_c = (pid + 1) * num_experts + + for i in range(tokens_per_thread): + if start_idx + i < numel: + idx = tl.load(topk_ids_ptr + start_idx + i) + token_cnt = tl.load(tokens_cnts_ptr + off_c + idx) + tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1) + + +@triton.jit +def moe_align_block_size_stage2( + tokens_cnts_ptr, + num_experts: tl.constexpr, +): + pid = tl.program_id(0) + last_cnt = 0 + for i in range(1, num_experts + 1): + token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid) + last_cnt = last_cnt + token_cnt + tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt) + + +@triton.jit +def moe_align_block_size_stage3( + total_tokens_post_pad_ptr, + tokens_cnts_ptr, + cumsum_ptr, + num_experts: tl.constexpr, + block_size: tl.constexpr, +): + last_cumsum = 0 + off_cnt = num_experts * num_experts + for i in range(1, num_experts + 1): + token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1) + last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size + tl.store(cumsum_ptr + i, last_cumsum) + tl.store(total_tokens_post_pad_ptr, last_cumsum) + + +@triton.jit +def moe_align_block_size_stage4( + topk_ids_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + tokens_cnts_ptr, + cumsum_ptr, + num_experts: tl.constexpr, + block_size: tl.constexpr, + numel: tl.constexpr, + tokens_per_thread: tl.constexpr, +): + pid = tl.program_id(0) + start_idx = tl.load(cumsum_ptr + pid) + end_idx = tl.load(cumsum_ptr + pid + 1) + + for i in range(start_idx, end_idx, block_size): + tl.store(expert_ids_ptr + i // block_size, pid) + + start_idx = pid * tokens_per_thread + off_t = pid * num_experts + + for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread, numel)): + expert_id = tl.load(topk_ids_ptr + i) + token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id) + rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id) + tl.store(sorted_token_ids_ptr + rank_post_pad, i) + tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1) + + +def moe_align_block_size_triton( + topk_ids: torch.Tensor, + num_experts: int, + block_size: int, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_pad: torch.Tensor, +) -> None: + numel = topk_ids.numel() + grid = (num_experts,) + tokens_cnts = torch.zeros( + (num_experts + 1, num_experts), dtype=torch.int32, device=topk_ids.device + ) + cumsum = torch.zeros((num_experts + 1,), dtype=torch.int32, device=topk_ids.device) + tokens_per_thread = ceil_div(numel, num_experts) + + moe_align_block_size_stage1[grid]( + topk_ids, + tokens_cnts, + num_experts, + numel, + tokens_per_thread, + ) + moe_align_block_size_stage2[grid]( + tokens_cnts, + num_experts, + ) + moe_align_block_size_stage3[(1,)]( + num_tokens_post_pad, + tokens_cnts, + cumsum, + num_experts, + block_size, + ) + moe_align_block_size_stage4[grid]( + topk_ids, + sorted_token_ids, + expert_ids, + tokens_cnts, + cumsum, + num_experts, + block_size, + numel, + tokens_per_thread, + ) + + +@pytest.mark.parametrize( + "block_size,num_tokens,topk", + list( + itertools.product( + [32, 64, 128, 256], # block_size + [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096], # num_tokens + [1, 2, 4, 8, 16, 32, 64], # topk + ) + ), +) +def test_moe_align_block_size_compare_implementations(block_size, num_tokens, topk): # For DeepSeek V3, we have 256 experts num_experts = 256 - # Test different combinations of block_size, num_tokens and topk - for block_size in [32, 64, 128, 256]: - print(f"\nTesting block_size={block_size}") - for num_tokens in [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096]: - for topk in [1, 2, 4, 8, 16, 32, 64]: - print( - f"Testing block_size={block_size}, num_tokens={num_tokens}, topk={topk}" - ) + topk_ids = torch.stack( + [ + torch.randperm(num_experts, dtype=torch.int32, device="cuda")[:topk] + for _ in range(num_tokens) + ] + ) - # Create random topk_ids with shape [num_tokens, topk] - topk_ids = torch.randint( - 0, num_experts, (num_tokens, topk), dtype=torch.int32, device="cuda" - ) + max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) - max_num_tokens_padded = topk_ids.numel() + num_experts * ( - block_size - 1 - ) - sorted_ids = torch.empty( - (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device - ) - sorted_ids.fill_(topk_ids.numel()) - max_num_m_blocks = max_num_tokens_padded // block_size - expert_ids = torch.empty( - (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device - ) - num_tokens_post_pad = torch.empty( - (1), dtype=torch.int32, device=topk_ids.device - ) + sorted_ids_cuda = torch.empty( + (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device + ) + sorted_ids_cuda.fill_(topk_ids.numel()) + max_num_m_blocks = max_num_tokens_padded // block_size + expert_ids_cuda = torch.zeros( + (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device + ) + num_tokens_post_pad_cuda = torch.empty( + (1), dtype=torch.int32, device=topk_ids.device + ) + token_cnts_buffer = torch.empty( + (num_experts + 1) * num_experts, + dtype=torch.int32, + device=topk_ids.device, + ) + cumsum_buffer = torch.empty( + num_experts + 1, dtype=torch.int32, device=topk_ids.device + ) - token_cnts_buffer = torch.empty( - (num_experts + 1) * num_experts, - dtype=torch.int32, - device=topk_ids.device, - ) - cumsum_buffer = torch.empty( - num_experts + 1, dtype=torch.int32, device=topk_ids.device - ) + sorted_ids_triton = torch.empty_like(sorted_ids_cuda) + sorted_ids_triton.fill_(topk_ids.numel()) + expert_ids_triton = torch.zeros_like(expert_ids_cuda) + num_tokens_post_pad_triton = torch.empty_like(num_tokens_post_pad_cuda) - try: - moe_align_block_size( - topk_ids, - num_experts, - block_size, - sorted_ids, - expert_ids, - num_tokens_post_pad, - token_cnts_buffer, - cumsum_buffer, - ) - except Exception as e: - print( - f"Error occurred with block_size={block_size}, num_tokens={num_tokens}, topk={topk}" - ) - print(f"Error message: {str(e)}") - raise e + moe_align_block_size( + topk_ids, + num_experts, + block_size, + sorted_ids_cuda, + expert_ids_cuda, + num_tokens_post_pad_cuda, + token_cnts_buffer, + cumsum_buffer, + ) + + moe_align_block_size_triton( + topk_ids, + num_experts, + block_size, + sorted_ids_triton, + expert_ids_triton, + num_tokens_post_pad_triton, + ) + + assert torch.allclose(expert_ids_cuda, expert_ids_triton), ( + f"Expert IDs mismatch for block_size={block_size}, " + f"num_tokens={num_tokens}, topk={topk}\n" + f"CUDA expert_ids: {expert_ids_cuda}\n" + f"Triton expert_ids: {expert_ids_triton}" + ) + + assert torch.allclose(num_tokens_post_pad_cuda, num_tokens_post_pad_triton), ( + f"Num tokens post pad mismatch for block_size={block_size}, " + f"num_tokens={num_tokens}, topk={topk}\n" + f"CUDA num_tokens_post_pad: {num_tokens_post_pad_cuda}\n" + f"Triton num_tokens_post_pad: {num_tokens_post_pad_triton}" + ) if __name__ == "__main__": - test_moe_align_block_size() + pytest.main([__file__])