From b93ef5e56d5ea0a4ecf6f79eba422b70c33384f9 Mon Sep 17 00:00:00 2001
From: lukec <118525388+sleepcoo@users.noreply.github.com>
Date: Sat, 8 Mar 2025 14:42:16 +0800
Subject: [PATCH] Remove the vllm dependency from the moe_align function
 (#4164)

Co-authored-by: Hongbosherlock
---
 .../sgl-kernel/csrc/moe/moe_align_kernel.cu | 18 ++++++++++--------
 sgl-kernel/tests/test_moe_align.py          |  8 +++++---
 2 files changed, 15 insertions(+), 11 deletions(-)

diff --git a/sgl-kernel/src/sgl-kernel/csrc/moe/moe_align_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/moe/moe_align_kernel.cu
index c5f37e556..83609a329 100644
--- a/sgl-kernel/src/sgl-kernel/csrc/moe/moe_align_kernel.cu
+++ b/sgl-kernel/src/sgl-kernel/csrc/moe/moe_align_kernel.cu
@@ -47,18 +47,18 @@ __global__ void moe_align_block_size_kernel(
     int32_t* __restrict__ expert_ids,
     int32_t* __restrict__ total_tokens_post_pad,
     int32_t num_experts,
+    int32_t experts_per_warp,
     int32_t block_size,
     size_t numel,
     int32_t* __restrict__ cumsum) {
-  __shared__ int32_t shared_counts[WARP_SIZE][8];
+  extern __shared__ int32_t shared_counts[];
 
   const int warp_id = threadIdx.x / WARP_SIZE;
-  const int experts_per_warp = 8;
   const int my_expert_start = warp_id * experts_per_warp;
 
   for (int i = 0; i < experts_per_warp; ++i) {
     if (my_expert_start + i < num_experts) {
-      shared_counts[warp_id][i] = 0;
+      shared_counts[warp_id * experts_per_warp + i] = 0;
     }
   }
 
@@ -71,7 +71,7 @@ __global__ void moe_align_block_size_kernel(
     int expert_id = topk_ids[i];
     int warp_idx = expert_id / experts_per_warp;
     int expert_offset = expert_id % experts_per_warp;
-    atomicAdd(&shared_counts[warp_idx][expert_offset], 1);
+    atomicAdd(&shared_counts[warp_idx * experts_per_warp + expert_offset], 1);
   }
 
   __syncthreads();
@@ -82,7 +82,7 @@ __global__ void moe_align_block_size_kernel(
       int expert_count = 0;
       int warp_idx = (i - 1) / experts_per_warp;
       int expert_offset = (i - 1) % experts_per_warp;
-      expert_count = shared_counts[warp_idx][expert_offset];
+      expert_count = shared_counts[warp_idx * experts_per_warp + expert_offset];
 
       cumsum[i] = cumsum[i - 1] + CEILDIV(expert_count, block_size) * block_size;
     }
@@ -108,16 +108,18 @@ void moe_align_block_size(
     torch::Tensor token_cnts_buffer,
     torch::Tensor cumsum_buffer) {
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  TORCH_CHECK(num_experts == 256, "moe_align_block_size kernel only support deepseek v3 now.");
-
+  TORCH_CHECK(num_experts % WARP_SIZE == 0);
+  int experts_per_warp = num_experts / WARP_SIZE;
   DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
     auto align_kernel = moe_align_block_size_kernel<scalar_t>;
-    align_kernel<<<1, 1024, 0, stream>>>(
+    size_t shared_mem_size = 32 * experts_per_warp * sizeof(int32_t);
+    align_kernel<<<1, 1024, shared_mem_size, stream>>>(
        topk_ids.data_ptr<scalar_t>(),
        sorted_token_ids.data_ptr<int32_t>(),
        experts_ids.data_ptr<int32_t>(),
        num_tokens_post_pad.data_ptr<int32_t>(),
        num_experts,
+        experts_per_warp,
        block_size,
        topk_ids.numel(),
        cumsum_buffer.data_ptr<int32_t>());
diff --git a/sgl-kernel/tests/test_moe_align.py b/sgl-kernel/tests/test_moe_align.py
index 81d05ffa1..3d89c3406 100644
--- a/sgl-kernel/tests/test_moe_align.py
+++ b/sgl-kernel/tests/test_moe_align.py
@@ -138,18 +138,20 @@ def moe_align_block_size_triton(
 @pytest.mark.parametrize(
-    "block_size,num_tokens,topk",
+    "block_size,num_tokens,topk,num_experts",
     list(
         itertools.product(
             [32, 64, 128, 256],  # block_size
             [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096],  # num_tokens
             [1, 2, 4, 8, 16, 32, 64],  # topk
+            [64, 160, 256],  # num_experts
         )
     ),
 )
-def test_moe_align_block_size_compare_implementations(block_size, num_tokens, topk):
+def test_moe_align_block_size_compare_implementations(
+    block_size, num_tokens, topk, num_experts
+):
     # For DeepSeek V3, we have 256 experts
-    num_experts = 256
     topk_ids = torch.stack(
         [
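
--
Reviewer note, appended outside the commit: the heart of this patch is
replacing the fixed __shared__ int32_t shared_counts[WARP_SIZE][8] table
(which hard-wired 32 * 8 = 256 experts, i.e. DeepSeek V3) with a flat
extern __shared__ array sized at launch time and indexed as
warp_idx * experts_per_warp + expert_offset. The sketch below is a
minimal, self-contained CUDA illustration of that pattern only.
count_experts_kernel and the host scaffolding are hypothetical names
invented for this note, not sgl-kernel APIs, and the sketch reproduces
just the shared-memory counting stage, not the sorting/padding logic.

#include <cstdio>
#include <cuda_runtime.h>

#define WARP_SIZE 32

// Counts how many times each expert id appears in topk_ids, using the
// same flat dynamic-shared-memory layout as the patched kernel:
// shared_counts[warp_idx * experts_per_warp + expert_offset].
__global__ void count_experts_kernel(
    const int* __restrict__ topk_ids,
    size_t numel,
    int num_experts,
    int experts_per_warp,
    int* __restrict__ counts_out) {
  // The array size comes from the third <<<>>> launch parameter, so the
  // same kernel serves any num_experts divisible by WARP_SIZE.
  extern __shared__ int shared_counts[];

  const int warp_id = threadIdx.x / WARP_SIZE;
  const int my_expert_start = warp_id * experts_per_warp;

  // Each warp zeroes its own slice of the counter table.
  for (int i = 0; i < experts_per_warp; ++i) {
    if (my_expert_start + i < num_experts) {
      shared_counts[warp_id * experts_per_warp + i] = 0;
    }
  }
  __syncthreads();

  // Block-stride accumulation into the shared counter table.
  for (size_t i = threadIdx.x; i < numel; i += blockDim.x) {
    int expert_id = topk_ids[i];
    int warp_idx = expert_id / experts_per_warp;
    int expert_offset = expert_id % experts_per_warp;
    atomicAdd(&shared_counts[warp_idx * experts_per_warp + expert_offset], 1);
  }
  __syncthreads();

  // For expert e, (e / experts_per_warp) * experts_per_warp +
  // e % experts_per_warp == e, so the flat table copies out directly.
  for (int e = threadIdx.x; e < num_experts; e += blockDim.x) {
    counts_out[e] = shared_counts[e];
  }
}

int main() {
  const int num_experts = 160;  // no longer pinned to 256
  const int experts_per_warp = num_experts / WARP_SIZE;  // 5
  const size_t numel = 1600;    // 10 occurrences of each expert id

  int *topk_ids, *counts;
  cudaMallocManaged(&topk_ids, numel * sizeof(int));
  cudaMallocManaged(&counts, num_experts * sizeof(int));
  for (size_t i = 0; i < numel; ++i) topk_ids[i] = i % num_experts;

  // Mirrors the patched launch: 1024 threads = 32 warps, and the
  // dynamic allocation holds 32 * experts_per_warp counters.
  size_t shared_mem_size = 32 * experts_per_warp * sizeof(int);
  count_experts_kernel<<<1, 1024, shared_mem_size>>>(
      topk_ids, numel, num_experts, experts_per_warp, counts);
  cudaDeviceSynchronize();

  printf("expert 0: %d, expert 159: %d (expect 10 each)\n",
         counts[0], counts[num_experts - 1]);
  cudaFree(topk_ids);
  cudaFree(counts);
  return 0;
}

Note on the sizing: with the fixed 1024-thread block (32 warps), the
allocation of 32 * experts_per_warp counters is exactly num_experts
slots once the host has checked num_experts % WARP_SIZE == 0, which is
also why the flat index for expert e reduces to e itself.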