diff --git a/sgl-kernel/csrc/moe/moe_align_kernel.cu b/sgl-kernel/csrc/moe/moe_align_kernel.cu
index 2d0061af8..ad80b0c75 100644
--- a/sgl-kernel/csrc/moe/moe_align_kernel.cu
+++ b/sgl-kernel/csrc/moe/moe_align_kernel.cu
@@ -21,16 +21,10 @@ limitations under the License.
 
 #include "utils.h"
 
-template <typename T, int N, int Alignment = sizeof(T) * N>
-class alignas(Alignment) AlignedArray {
- public:
-  T data[N];
-};
-
 #define WARP_SIZE 32
 #define VEC_SIZE 4
 
-using Vec = AlignedArray<int32_t, VEC_SIZE>;
+using Vec = int4;
 
 template <typename scalar_t>
 __global__ void count_and_sort_expert_tokens_kernel(
@@ -55,73 +49,119 @@ __global__ void moe_align_block_size_kernel(
     int32_t* __restrict__ expert_ids,
     int32_t* __restrict__ total_tokens_post_pad,
     int32_t num_experts,
-    int32_t padded_num_experts,
-    int32_t experts_per_warp,
     int32_t block_size,
     size_t numel,
     int32_t* __restrict__ cumsum,
-    bool pad_sorted_token_ids) {
-  extern __shared__ int32_t shared_counts[];
-
-  const int warp_id = threadIdx.x / WARP_SIZE;
-  const int my_expert_start = warp_id * experts_per_warp;
-
-  for (int i = 0; i < experts_per_warp; ++i) {
-    if (my_expert_start + i < padded_num_experts) {
-      shared_counts[warp_id * experts_per_warp + i] = 0;
-    }
-  }
-
-  __syncthreads();
+    bool pad_sorted_token_ids,
+    const int32_t scan_size) {
+  extern __shared__ int32_t smem[];
+  int32_t* shared_counts = smem;                  // [num_experts]
+  int32_t* prefix = shared_counts + num_experts;  // [num_experts + 1]
+  int32_t* scan_buf = prefix + num_experts + 1;   // [scan_size]
+  __shared__ int32_t s_total_tokens_post_pad;
 
   const size_t tid = threadIdx.x;
   const size_t stride = blockDim.x;
 
+  if (tid < num_experts) {
+    shared_counts[tid] = 0;
+  }
+
+  __syncthreads();
+
   for (size_t i = tid; i < numel; i += stride) {
     int expert_id = topk_ids[i];
-    int warp_idx = expert_id / experts_per_warp;
-    int expert_offset = expert_id % experts_per_warp;
-    atomicAdd(&shared_counts[warp_idx * experts_per_warp + expert_offset], 1);
+    atomicAdd(&shared_counts[expert_id], 1);
   }
 
   __syncthreads();
 
-  if (threadIdx.x == 0) {
-    cumsum[0] = 0;
-    for (int i = 1; i <= num_experts; ++i) {
-      int expert_count = 0;
-      int warp_idx = (i - 1) / experts_per_warp;
-      int expert_offset = (i - 1) % experts_per_warp;
-      expert_count = shared_counts[warp_idx * experts_per_warp + expert_offset];
+  int32_t padded_count = 0;
+  if (tid < num_experts) {
+    int32_t count = shared_counts[tid];
+    padded_count = (count + block_size - 1) / block_size * block_size;
+    scan_buf[tid] = padded_count;
+  }
 
-      cumsum[i] = cumsum[i - 1] + CEILDIV(expert_count, block_size) * block_size;
-    }
-    *total_tokens_post_pad = cumsum[num_experts];
+  if (tid >= num_experts && tid < scan_size) {
+    scan_buf[tid] = 0;
   }
 
   __syncthreads();
 
-  if (threadIdx.x < num_experts) {
-    for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; i += block_size) {
-      expert_ids[i / block_size] = threadIdx.x;
+  // Blelloch exclusive scan over the padded per-expert counts: up-sweep (reduce)
+  int offset = 1;
+#pragma unroll
+  for (int d = scan_size >> 1; d > 0; d >>= 1) {
+    if (tid < d) {
+      int ai = offset * (2 * tid + 1) - 1;
+      int bi = offset * (2 * tid + 2) - 1;
+      scan_buf[bi] += scan_buf[ai];
     }
+    offset <<= 1;
+    __syncthreads();
+  }
+
+  // down-sweep: turn the reduction tree into an exclusive prefix sum
+  if (tid == 0) {
+    prefix[num_experts] = scan_buf[scan_size - 1];
+    scan_buf[scan_size - 1] = 0;
+  }
+  __syncthreads();
+
+#pragma unroll
+  for (int d = 1; d < scan_size; d <<= 1) {
+    offset >>= 1;
+    if (tid < d) {
+      int ai = offset * (2 * tid + 1) - 1;
+      int bi = offset * (2 * tid + 2) - 1;
+      if (bi < scan_size) {
+        int temp = scan_buf[ai];
+        scan_buf[ai] = scan_buf[bi];
+        scan_buf[bi] += temp;
+      }
+    }
+    __syncthreads();
+  }
+
+  if (tid < num_experts) {
+    prefix[tid] = scan_buf[tid];
+  }
+
+  if (tid == 0) {
+    s_total_tokens_post_pad = prefix[num_experts];
+    *total_tokens_post_pad = s_total_tokens_post_pad;
+  }
+
+  __syncthreads();
+
+  if (tid <= num_experts) {
+    cumsum[tid] = prefix[tid];
+  }
+
+  // fill expert_ids: binary-search the prefix sum for the expert that owns each block
+  const int32_t num_blocks = s_total_tokens_post_pad / block_size;
+  for (int32_t i = tid; i < num_blocks; i += stride) {
+    int32_t block_start = i * block_size;
+    int left = 0, right = num_experts;
+    while (left < right) {
+      int mid = (left + right) >> 1;
+      if (prefix[mid] <= block_start) {
+        left = mid + 1;
+      } else {
+        right = mid;
+      }
+    }
+    expert_ids[i] = left - 1;
   }
 
   if (pad_sorted_token_ids) {
-    int32_t fill_val = static_cast<int32_t>(numel);
-    int32_t total = *total_tokens_post_pad;
-    Vec fill_vec;
-#pragma unroll
-    for (int i = 0; i < VEC_SIZE; ++i) {
-      fill_vec.data[i] = fill_val;
-    }
-
-    int32_t total_vec_count = (total + VEC_SIZE - 1) / VEC_SIZE;
+    Vec fill_vec;
+    fill_vec.x = fill_vec.y = fill_vec.z = fill_vec.w = numel;
+    int32_t total_vecs = (s_total_tokens_post_pad + VEC_SIZE - 1) / VEC_SIZE;
     Vec* out_ptr = reinterpret_cast<Vec*>(sorted_token_ids);
-
-    for (int32_t idx = tid; idx < total_vec_count; idx += stride) {
-      out_ptr[idx] = fill_vec;
+    for (int32_t i = tid; i < total_vecs; i += stride) {
+      out_ptr[i] = fill_vec;
     }
   }
 }
@@ -179,20 +219,12 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
   }
 
   if (pad_sorted_token_ids) {
-    int32_t fill_val = static_cast<int32_t>(numel);
-    int32_t total = *total_tokens_post_pad;
-    Vec fill_vec;
-#pragma unroll
-    for (int i = 0; i < VEC_SIZE; ++i) {
-      fill_vec.data[i] = fill_val;
-    }
-
-    int32_t total_vec_count = (total + VEC_SIZE - 1) / VEC_SIZE;
+    Vec fill_vec;
+    fill_vec.x = fill_vec.y = fill_vec.z = fill_vec.w = numel;
+    int32_t total_vecs = (*total_tokens_post_pad + VEC_SIZE - 1) / VEC_SIZE;
     Vec* out_ptr = reinterpret_cast<Vec*>(sorted_token_ids);
-
-    for (int32_t idx = tid; idx < total_vec_count; idx += stride) {
-      out_ptr[idx] = fill_vec;
+    for (int32_t i = tid; i < total_vecs; i += stride) {
+      out_ptr[i] = fill_vec;
     }
   }
 
@@ -245,8 +277,8 @@ void moe_align_block_size(
   } else {
     auto align_kernel = moe_align_block_size_kernel<scalar_t>;
-    size_t num_warps = CEILDIV(padded_num_experts, experts_per_warp);
-    size_t shared_mem_size = num_warps * experts_per_warp * sizeof(int32_t);
+    const size_t scan_size = next_pow2(num_experts);
+    const size_t shared_mem_size = (num_experts + (num_experts + 1) + scan_size) * sizeof(int32_t);
 
     align_kernel<<<1, threads, shared_mem_size, stream>>>(
         topk_ids.data_ptr<scalar_t>(),
@@ -254,12 +286,11 @@
         experts_ids.data_ptr<int32_t>(),
         num_tokens_post_pad.data_ptr<int32_t>(),
         num_experts,
-        padded_num_experts,
-        experts_per_warp,
         block_size,
         topk_ids.numel(),
        cumsum_buffer.data_ptr<int32_t>(),
-        pad_sorted_token_ids);
+        pad_sorted_token_ids,
+        scan_size);
 
     const int block_threads = std::min(256, (int)threads);
     const int num_blocks = (topk_ids.numel() + block_threads - 1) / block_threads;
diff --git a/sgl-kernel/include/utils.h b/sgl-kernel/include/utils.h
index bce2de24e..1054dbc52 100644
--- a/sgl-kernel/include/utils.h
+++ b/sgl-kernel/include/utils.h
@@ -363,3 +363,9 @@ inline torch::Tensor pad_tensor(const torch::Tensor& tensor, int64_t alignment =
   }
   return tensor_padded;
 }
+
+// Round x up to the next power of two (returns 1 for x <= 1)
+inline uint32_t next_pow2(uint32_t x) noexcept {
+  if (x <= 1) return 1;
+  return 1u << (32 - __builtin_clz(x - 1));
+}
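Note on the new algorithm (illustrative, not part of the patch): the rewritten kernel builds a per-expert histogram in shared memory, pads each count up to a multiple of block_size, converts the padded counts into an exclusive prefix sum with a Blelloch scan over a power-of-two-sized buffer (scan_size = next_pow2(num_experts), which is why the launch requests num_experts + (num_experts + 1) + scan_size int32 words of shared memory), and then assigns each output block to an expert by binary-searching that prefix sum. The sketch below is a minimal host-side C++ reference for the same computation, useful when checking the kernel against expected outputs; the names AlignResult and reference_moe_align are hypothetical and do not exist in the repository.

#include <cstdint>
#include <vector>

struct AlignResult {
  std::vector<int32_t> cumsum;      // exclusive prefix sum of padded counts, [num_experts + 1]
  std::vector<int32_t> expert_ids;  // owning expert for each block_size-sized output block
  int32_t total_tokens_post_pad = 0;
};

AlignResult reference_moe_align(const std::vector<int32_t>& topk_ids, int32_t num_experts, int32_t block_size) {
  // Histogram of tokens per expert (the shared_counts atomics in the kernel).
  std::vector<int32_t> counts(num_experts, 0);
  for (int32_t e : topk_ids) counts[e]++;

  AlignResult r;
  r.cumsum.assign(num_experts + 1, 0);
  for (int32_t i = 0; i < num_experts; ++i) {
    // Pad each expert's count up to a multiple of block_size, then prefix-sum.
    // The kernel produces the same values in parallel with the Blelloch scan over scan_buf.
    int32_t padded = (counts[i] + block_size - 1) / block_size * block_size;
    r.cumsum[i + 1] = r.cumsum[i] + padded;
  }
  r.total_tokens_post_pad = r.cumsum[num_experts];

  // One entry per output block; same binary search as the kernel's expert_ids fill.
  const int32_t num_blocks = r.total_tokens_post_pad / block_size;
  r.expert_ids.resize(num_blocks);
  for (int32_t b = 0; b < num_blocks; ++b) {
    const int32_t block_start = b * block_size;
    int32_t left = 0, right = num_experts;
    while (left < right) {
      int32_t mid = (left + right) >> 1;
      if (r.cumsum[mid] <= block_start) {
        left = mid + 1;
      } else {
        right = mid;
      }
    }
    r.expert_ids[b] = left - 1;  // last expert whose start offset is <= block_start
  }
  return r;
}

Compared with the old kernel, which counted per warp and then ran a serial cumulative sum on thread 0, the block-wide scan removes the O(num_experts) serial loop; the binary-search fill likewise replaces the per-expert strided writes when populating expert_ids.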