From ded9fcd09a43d5e7d5bb31a2bc3e9fc21bf65d2a Mon Sep 17 00:00:00 2001
From: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>
Date: Mon, 6 Jan 2025 00:28:22 +0800
Subject: [PATCH] improve moe_align_kernel for deepseek v3 (#2735)

---
 .../src/sgl-kernel/csrc/moe_align_kernel.cu | 84 ++++++++-----------
 sgl-kernel/tests/test_trt_reduce.py         |  1 -
 2 files changed, 33 insertions(+), 52 deletions(-)

diff --git a/sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu
index dfd28032f..31bb97fc4 100644
--- a/sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu
+++ b/sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu
@@ -46,74 +46,61 @@ __device__ __forceinline__ int32_t index(int32_t total_col, int32_t row, int32_t
 template <typename scalar_t>
 __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids,
                                             int32_t* expert_ids, int32_t* total_tokens_post_pad, int32_t num_experts,
-                                            int32_t block_size, size_t numel, int32_t* tokens_cnts, int32_t* cumsum) {
-  const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
-  const size_t start_idx = threadIdx.x * tokens_per_thread;
+                                            int32_t block_size, size_t numel, int32_t* cumsum) {
+  __shared__ int32_t shared_counts[32][8];
+  __shared__ int32_t local_offsets[256];
 
-  for (int i = 0; i < num_experts; ++i) {
-    tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
-  }
+  const int warp_id = threadIdx.x / WARP_SIZE;
+  const int lane_id = threadIdx.x % WARP_SIZE;
+  const int experts_per_warp = 8;
+  const int my_expert_start = warp_id * experts_per_warp;
 
-  /**
-   * In the first step we compute token_cnts[thread_index + 1][expert_index],
-   * which counts how many tokens in the token shard of thread_index are
-   * assigned to expert expert_index.
-   */
-  for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
-    ++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])];
-  }
-
-  __syncthreads();
-
-  // For each expert we accumulate the token counts from the different threads.
-  if (threadIdx.x < num_experts) {
-    tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
-    for (int i = 1; i <= blockDim.x; ++i) {
-      tokens_cnts[index(num_experts, i, threadIdx.x)] += tokens_cnts[index(num_experts, i - 1, threadIdx.x)];
+  for (int i = 0; i < experts_per_warp; ++i) {
+    if (my_expert_start + i < num_experts) {
+      shared_counts[warp_id][i] = 0;
     }
   }
 
+  const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
+  const size_t start_idx = threadIdx.x * tokens_per_thread;
+
+  for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
+    int expert_id = topk_ids[i];
+    int warp_idx = expert_id / experts_per_warp;
+    int expert_offset = expert_id % experts_per_warp;
+    atomicAdd(&shared_counts[warp_idx][expert_offset], 1);
+  }
+
   __syncthreads();
 
-  // We accumulate the token counts of all experts in thread 0.
   if (threadIdx.x == 0) {
     cumsum[0] = 0;
     for (int i = 1; i <= num_experts; ++i) {
-      cumsum[i] = cumsum[i - 1] + CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)], block_size) * block_size;
+      int expert_count = 0;
+      int warp_idx = (i - 1) / experts_per_warp;
+      int expert_offset = (i - 1) % experts_per_warp;
+      expert_count = shared_counts[warp_idx][expert_offset];
+
+      cumsum[i] = cumsum[i - 1] + CEILDIV(expert_count, block_size) * block_size;
     }
     *total_tokens_post_pad = cumsum[num_experts];
   }
 
   __syncthreads();
 
-  /**
-   * For each expert, each thread processes the tokens of the corresponding
-   * blocks and stores the corresponding expert_id for each block.
-   */
   if (threadIdx.x < num_experts) {
     for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; i += block_size) {
       expert_ids[i / block_size] = threadIdx.x;
     }
+    local_offsets[threadIdx.x] = cumsum[threadIdx.x];
   }
 
-  /**
-   * Each thread processes a token shard, calculating the index of each token
-   * after sorting by expert number. Given the example topk_ids =
-   * [0,1,2,1,2,3,0,3,4] and block_size = 4, then the output would be [0, 6, *,
-   * *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], where * represents a
-   * padding value(preset in python).
-   */
+  __syncthreads();
+
   for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
     int32_t expert_id = topk_ids[i];
-    /** The cumsum[expert_id] stores the starting index of the tokens that the
-     * expert with expert_id needs to process, and
-     * tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens
-     * processed by the expert with expert_id within the current thread's token
-     * shard.
-     */
-    int32_t rank_post_pad = tokens_cnts[index(num_experts, threadIdx.x, expert_id)] + cumsum[expert_id];
+    int32_t rank_post_pad = atomicAdd(&local_offsets[expert_id], 1);
     sorted_token_ids[rank_post_pad] = i;
-    ++tokens_cnts[index(num_experts, threadIdx.x, expert_id)];
   }
 }
 
@@ -122,14 +109,9 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t b
                           torch::Tensor token_cnts_buffer, torch::Tensor cumsum_buffer) {
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
-    // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
-    // tensors
-    const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
-
     auto kernel = moe_align_block_size_kernel<scalar_t>;
-    kernel<<<1, num_thread, 0, stream>>>(topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
-                                         experts_ids.data_ptr<int32_t>(), num_tokens_post_pad.data_ptr<int32_t>(),
-                                         num_experts, block_size, topk_ids.numel(),
-                                         token_cnts_buffer.data_ptr<int32_t>(), cumsum_buffer.data_ptr<int32_t>());
+    kernel<<<1, 1024, 0, stream>>>(topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
+                                   experts_ids.data_ptr<int32_t>(), num_tokens_post_pad.data_ptr<int32_t>(),
+                                   num_experts, block_size, topk_ids.numel(), cumsum_buffer.data_ptr<int32_t>());
   });
 }
diff --git a/sgl-kernel/tests/test_trt_reduce.py b/sgl-kernel/tests/test_trt_reduce.py
index c16d6bb6f..d6265118f 100644
--- a/sgl-kernel/tests/test_trt_reduce.py
+++ b/sgl-kernel/tests/test_trt_reduce.py
@@ -36,7 +36,6 @@ def multi_process_parallel(
     cls: Any,
     test_target: Any,
 ) -> None:
-
     # Using ray helps debugging the error when it failed
     # as compared to multiprocessing.
     # NOTE: We need to set working_dir for distributed tests,
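
Reviewer note, appended for context rather than as part of the patch: the rewrite drops the O(num_threads x num_experts) tokens_cnts matrix in favor of shared-memory histogram counters updated with atomicAdd, a single-thread padded prefix sum, and per-expert write cursors (local_offsets) for the final scatter. The sketch below is a minimal, self-contained, single-block illustration of that three-phase scheme. It is a simplification, not the patch's code: the warp-partitioned shared_counts[32][8] layout is flattened to one counts[] array, cumsum lives in shared memory instead of a caller-provided buffer, tokens are walked with a block-stride loop rather than contiguous per-thread shards, and the kernel name moe_align_sketch plus all sizes are invented for the example.

#include <cstdio>
#include <cuda_runtime.h>

#define CEILDIV(x, y) (((x) + (y) - 1) / (y))

constexpr int kMaxExperts = 8;  // toy bound; the patched kernel supports up to 256 experts

__global__ void moe_align_sketch(const int* topk_ids, int* sorted_token_ids, int* expert_ids,
                                 int* total_tokens_post_pad, int num_experts, int block_size, int numel) {
  __shared__ int counts[kMaxExperts];      // phase 1: per-expert token histogram
  __shared__ int cumsum[kMaxExperts + 1];  // phase 2: block_size-padded prefix sum
  __shared__ int offsets[kMaxExperts];     // phase 3: running write cursor per expert

  if (threadIdx.x < num_experts) counts[threadIdx.x] = 0;
  __syncthreads();

  // Phase 1: count expert assignments with shared-memory atomics.
  for (int i = threadIdx.x; i < numel; i += blockDim.x) {
    atomicAdd(&counts[topk_ids[i]], 1);
  }
  __syncthreads();

  // Phase 2: one thread rounds each expert's count up to a multiple of
  // block_size and accumulates the padded prefix sum.
  if (threadIdx.x == 0) {
    cumsum[0] = 0;
    for (int e = 1; e <= num_experts; ++e) {
      cumsum[e] = cumsum[e - 1] + CEILDIV(counts[e - 1], block_size) * block_size;
    }
    *total_tokens_post_pad = cumsum[num_experts];
  }
  __syncthreads();

  // Phase 3: label each block with its expert, then scatter token indices;
  // atomicAdd hands every token a unique slot inside its expert's range.
  if (threadIdx.x < num_experts) {
    for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; i += block_size) {
      expert_ids[i / block_size] = threadIdx.x;
    }
    offsets[threadIdx.x] = cumsum[threadIdx.x];
  }
  __syncthreads();

  for (int i = threadIdx.x; i < numel; i += blockDim.x) {
    int pos = atomicAdd(&offsets[topk_ids[i]], 1);
    sorted_token_ids[pos] = i;  // padding slots stay untouched (preset by the caller in the real flow)
  }
}

int main() {
  // The example from the deleted comment: topk_ids = [0,1,2,1,2,3,0,3,4], block_size = 4.
  const int h_topk[] = {0, 1, 2, 1, 2, 3, 0, 3, 4};
  const int numel = 9, num_experts = 5, block_size = 4;
  const int max_padded = numel + num_experts * block_size;  // upper bound on padded length

  int *d_topk, *d_sorted, *d_expert_ids, *d_total;
  cudaMalloc(&d_topk, sizeof(h_topk));
  cudaMalloc(&d_sorted, max_padded * sizeof(int));
  cudaMalloc(&d_expert_ids, CEILDIV(max_padded, block_size) * sizeof(int));
  cudaMalloc(&d_total, sizeof(int));
  cudaMemcpy(d_topk, h_topk, sizeof(h_topk), cudaMemcpyHostToDevice);

  moe_align_sketch<<<1, 128>>>(d_topk, d_sorted, d_expert_ids, d_total, num_experts, block_size, numel);

  int total = 0;
  cudaMemcpy(&total, d_total, sizeof(int), cudaMemcpyDeviceToHost);
  printf("total_tokens_post_pad = %d\n", total);  // expected: 20 (each of the 5 experts padded to 4)
  cudaFree(d_topk); cudaFree(d_sorted); cudaFree(d_expert_ids); cudaFree(d_total);
  return 0;
}

One behavioral difference worth noting: the old kernel produced a deterministic intra-expert order (hence the [0, 6, *, *, 1, 3, ...] example in the deleted comment), whereas the atomicAdd cursor makes the order of tokens within one expert's range depend on thread scheduling; presumably acceptable because downstream only needs tokens grouped by expert and block.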