From 18bb216c28e6d92c95f700b6aa71dd67cab3335a Mon Sep 17 00:00:00 2001 From: Chayenne Date: Fri, 28 Feb 2025 23:57:17 -0800 Subject: [PATCH] Revert "[MOE] enable efficient moe_alignment multi-blocks execution (3x~6x)" (#3982) --- .../benchmark_deepseekv3_moe_align_blocks.py | 121 +++---- sgl-kernel/pyproject.toml | 2 +- .../src/sgl-kernel/csrc/moe_align_kernel.cu | 318 +++--------------- sgl-kernel/src/sgl-kernel/include/utils.h | 30 -- sgl-kernel/tests/test_moe_align.py | 4 +- 5 files changed, 94 insertions(+), 381 deletions(-) diff --git a/benchmark/kernels/fused_moe_triton/benchmark_deepseekv3_moe_align_blocks.py b/benchmark/kernels/fused_moe_triton/benchmark_deepseekv3_moe_align_blocks.py index ff8be14a3..1d9504d0a 100644 --- a/benchmark/kernels/fused_moe_triton/benchmark_deepseekv3_moe_align_blocks.py +++ b/benchmark/kernels/fused_moe_triton/benchmark_deepseekv3_moe_align_blocks.py @@ -99,12 +99,13 @@ def moe_align_block_size_triton( sorted_token_ids: torch.Tensor, expert_ids: torch.Tensor, num_tokens_post_pad: torch.Tensor, - tokens_cnts: torch.Tensor, - cumsum: torch.Tensor, ) -> None: numel = topk_ids.numel() grid = (num_experts,) - + tokens_cnts = torch.zeros( + (num_experts + 1, num_experts), dtype=torch.int32, device=topk_ids.device + ) + cumsum = torch.zeros((num_experts + 1,), dtype=torch.int32, device=topk_ids.device) tokens_per_thread = ceil_div(numel, num_experts) moe_align_block_size_stage1[grid]( @@ -138,18 +139,11 @@ def moe_align_block_size_triton( ) -def calculate_diff(batch_size, seq_len, num_experts): - num_experts = num_experts +def calculate_diff(batch_size, seq_len): + num_experts = 256 block_size = 128 topk = 8 - assert batch_size >= 1 - assert seq_len >= 1 - assert num_experts >= 4 - - if topk > num_experts: - topk = num_experts - topk_ids = torch.stack( [ torch.randperm(num_experts, dtype=torch.int32, device="cuda")[:topk] @@ -181,13 +175,6 @@ def calculate_diff(batch_size, seq_len, num_experts): expert_ids_triton = torch.zeros_like(expert_ids_cuda) num_tokens_post_pad_triton = torch.empty_like(num_tokens_post_pad_cuda) - token_cnts_buffer_triton = torch.zeros( - (num_experts + 1, num_experts), dtype=torch.int32, device=topk_ids.device - ) - cumsum_buffer_triton = torch.zeros( - (num_experts + 1,), dtype=torch.int32, device=topk_ids.device - ) - # compare the performance of cuda and triton implementation moe_align_block_size( topk_ids, @@ -206,27 +193,14 @@ def calculate_diff(batch_size, seq_len, num_experts): sorted_ids_triton, expert_ids_triton, num_tokens_post_pad_triton, - token_cnts_buffer_triton, - cumsum_buffer_triton, ) - sorted_ids_cuda_snapshot = sorted_ids_cuda[: cumsum_buffer[1]].sort().values - sorted_ids_triton_snapshot = sorted_ids_triton[: cumsum_buffer[1]].sort().values - - if ( - torch.allclose(expert_ids_cuda, expert_ids_triton) - and torch.allclose(num_tokens_post_pad_cuda, num_tokens_post_pad_triton) - and torch.allclose(sorted_ids_cuda_snapshot, sorted_ids_triton_snapshot) + if torch.allclose(expert_ids_cuda, expert_ids_triton) and torch.allclose( + num_tokens_post_pad_cuda, num_tokens_post_pad_triton ): - print( - "✅ CUDA and Triton implementations match : num_tokens={}, num_experts={}".format( - batch_size * seq_len, num_experts - ) - ) + print("✅ CUDA and Triton implementations match") else: print("❌ CUDA and Triton implementations do not match") - print("CUDA sorted ids:", sorted_ids_cuda_snapshot) - print("Triton sorted ids:", sorted_ids_triton_snapshot) print("CUDA expert_ids:", expert_ids_cuda) print("Triton expert_ids:", expert_ids_triton) print("CUDA num_tokens_post_pad:", num_tokens_post_pad_cuda) @@ -282,7 +256,7 @@ def benchmark(batch_size, seq_len, provider): ) sorted_ids.fill_(topk_ids.numel()) max_num_m_blocks = max_num_tokens_padded // block_size - expert_ids = torch.zeros( + expert_ids = torch.empty( (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device ) num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device) @@ -293,37 +267,34 @@ def benchmark(batch_size, seq_len, provider): num_experts + 1, dtype=torch.int32, device=topk_ids.device ) - # Warm up - api_func = ( - moe_align_block_size if provider == "cuda" else moe_align_block_size_triton - ) - for _ in range(10): - api_func( - topk_ids, - num_experts, - block_size, - sorted_ids, - expert_ids, - num_tokens_post_pad, - token_cnts_buffer.clone(), - cumsum_buffer.clone(), - ) - torch.cuda.synchronize() - quantiles = [0.5, 0.2, 0.8] - ms, min_ms, max_ms = triton.testing.do_bench( - lambda: api_func( - topk_ids, - num_experts, - block_size, - sorted_ids, - expert_ids, - num_tokens_post_pad, - token_cnts_buffer.clone(), - cumsum_buffer.clone(), - ), - quantiles=quantiles, - ) + if provider == "cuda": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: moe_align_block_size( + topk_ids, + num_experts, + block_size, + sorted_ids.clone(), + expert_ids.clone(), + num_tokens_post_pad.clone(), + token_cnts_buffer, + cumsum_buffer, + ), + quantiles=quantiles, + ) + else: + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: moe_align_block_size_triton( + topk_ids, + num_experts, + block_size, + sorted_ids.clone(), + expert_ids.clone(), + num_tokens_post_pad.clone(), + ), + quantiles=quantiles, + ) + return 1000 * ms, 1000 * max_ms, 1000 * min_ms @@ -335,22 +306,8 @@ if __name__ == "__main__": default="./configs/benchmark_ops/moe_align_blocks/", help="Path to save moe align benchmark results", ) - parser.add_argument( - "--verify", - action="store_true", - help="verify kernel", - ) - args = parser.parse_args() - if args.verify: - num_experts_range = [2**i for i in range(3, 9)] + calculate_diff(batch_size=4, seq_len=1024) - configs = list( - itertools.product(batch_size_range, seq_length_range, num_experts_range) - ) - - for bs, seq, num_experts in configs: - calculate_diff(batch_size=bs, seq_len=seq, num_experts=num_experts) - else: - benchmark.run(print_data=True, save_path=args.save_path) + benchmark.run(print_data=True) diff --git a/sgl-kernel/pyproject.toml b/sgl-kernel/pyproject.toml index 953db4f48..b9f79fce6 100644 --- a/sgl-kernel/pyproject.toml +++ b/sgl-kernel/pyproject.toml @@ -1,5 +1,5 @@ [build-system] -requires = ["setuptools>=61.0", "wheel", "torch<=2.5.1"] +requires = ["setuptools>=61.0", "wheel", "torch"] build-backend = "setuptools.build_meta" [project] diff --git a/sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu index 592c5bff9..473aae6f5 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu @@ -16,284 +16,77 @@ limitations under the License. #include #include #include -#include #include #include #include "utils.h" -#define MAX_NUM_EXPERTS 256 -#define EXPERTS_PER_WARP ((MAX_NUM_EXPERTS) / (WARP_SIZE)) +#define WARP_SIZE 32 -#define FRAGS_PER_BLOCK 4 +template +__global__ void count_and_sort_expert_tokens_kernel(const scalar_t* __restrict__ topk_ids, + int32_t* __restrict__ sorted_token_ids, + int32_t* __restrict__ cumsum_buffer, size_t numel) { + const size_t tid = blockIdx.x * blockDim.x + threadIdx.x; + const size_t stride = blockDim.x * gridDim.x; -#define FRAG_SIZE_M 16 -#define FRAG_SIZE_N 16 - -#ifndef USE_ROCM -#define kWarpsToLoad 2 -#else -#define kWarpsToLoad 1 -#endif - -#define kElementsPerAccess 4 -#define kElementsPerThr 16 - -#define SGLANG_FORCE_INLINE_DEVICE_FUNC static __forceinline__ __attribute__((always_inline)) __device__ - -namespace cg = cooperative_groups; - -SGLANG_FORCE_INLINE_DEVICE_FUNC void store_global_cumsum(int* cumsum /*dest*/, int* total_tokens_post_pad /*dest*/, - const int32_t* local_offsets, const int& tid, - const int& num_experts, cg::grid_group& grid) { - int active_threads = CEILDIV(num_experts + 1, kElementsPerThr); - if (tid < active_threads - 1) { - for (int i = tid * kElementsPerThr; i < (tid + 1) * kElementsPerThr; i += kElementsPerAccess) { - *(int4*)(cumsum + i) = *(int4*)(local_offsets + i); - } + for (size_t i = tid; i < numel; i += stride) { + int32_t expert_id = topk_ids[i]; + int32_t rank_post_pad = atomicAdd(&cumsum_buffer[expert_id], 1); + sorted_token_ids[rank_post_pad] = i; } - - if (tid == active_threads - 1) { -#pragma unroll - for (int i = tid * kElementsPerThr; i < num_experts + 1; i++) { - *(cumsum + i) = *(local_offsets + i); - } - } - if (tid == active_threads) { - *total_tokens_post_pad = local_offsets[num_experts]; - } - __threadfence_system(); - grid.sync(); -} - -SGLANG_FORCE_INLINE_DEVICE_FUNC void align_global_cumsum(int32_t* local_offsets /*src_and_dest*/, - int32_t* local_offsets_buf, int* smem_ptr, const int tid, - const int32_t& block_size, const int32_t& num_experts) { - int active_threads = CEILDIV(num_experts, kElementsPerThr); - int start = tid * kElementsPerThr + 1; - int end = MIN((tid + 1) * kElementsPerThr, num_experts) + 1; - if (tid == 0) { - smem_ptr[0] = 0; - } - if (tid < active_threads) { - for (int i = start; i < end; ++i) { - smem_ptr[i] = local_offsets[i] - local_offsets[i - 1]; - } - } - __syncthreads(); - - if (tid < active_threads) { - for (int i = start; i < end; ++i) { - int last_val = (i - 1) % kElementsPerThr == 0 ? 0 : local_offsets[i - 1]; - local_offsets[i] = last_val + CEILDIV(smem_ptr[i], block_size) * block_size; - } - - local_offsets_buf[tid] = local_offsets[end - 1]; - } - __syncthreads(); - - if (tid < active_threads && tid > 0) { - int offset = 0; - for (int j = 0; j < tid; ++j) { - offset += local_offsets_buf[j]; - } - - for (int i = start; i < end; ++i) { - local_offsets[i] += offset; - } - } - __syncthreads(); -} - -SGLANG_FORCE_INLINE_DEVICE_FUNC void reduce_unaligned_cumsum(int* tokens_cnts_ptr /*src_and_dest*/, int* smem_ptr, - int32_t* local_offsets, const int& tid, const int& lane_id, - const int& warp_id, const int32_t& num_experts, - cg::grid_group& grid) { - int total_fragments = CEILDIV(num_experts, FRAG_SIZE_N); - int fragments_per_block = CEILDIV(total_fragments, gridDim.x); - int fragments_per_warp = CEILDIV(fragments_per_block, FRAGS_PER_BLOCK); - - for (int i = 0; i < gridDim.x; i += FRAG_SIZE_M) { - for (int j = 0; j < fragments_per_warp; j++) { - if (warp_id * fragments_per_warp < kWarpsToLoad * fragments_per_block) { - const int kNumThrPerRow = WARP_SIZE / FRAG_SIZE_N; - int sRow = lane_id / kNumThrPerRow; - - int sWarpColStride = kNumThrPerRow * kElementsPerAccess; - int sWarpColOff = warp_id * sWarpColStride; - int sThrColOff = lane_id % kNumThrPerRow * kElementsPerAccess; - - int sCol = sThrColOff + sWarpColOff; - - int gRow = i + sRow; - - int gBlockColOff = blockIdx.x * fragments_per_block * FRAG_SIZE_N; - int gWarpColOff_0 = (warp_id / kWarpsToLoad * fragments_per_warp + j) * FRAG_SIZE_N; - int gWarpColOff_1 = warp_id % kWarpsToLoad * sWarpColStride; - - int gCol = gBlockColOff + gWarpColOff_0 + gWarpColOff_1 + sThrColOff; - - if (gRow < num_experts && gCol < num_experts) { - int4* tokens_cnts_4i_ptr = (int4*)(tokens_cnts_ptr + (gRow + 1) * num_experts + gCol); - int4* smem_4i_ptr = (int4*)(smem_ptr + sRow * FRAGS_PER_BLOCK * FRAG_SIZE_N + sCol); - - *smem_4i_ptr = *tokens_cnts_4i_ptr; - } - } - __syncthreads(); - - if (warp_id * fragments_per_warp < kWarpsToLoad * fragments_per_block) { - if (warp_id % kWarpsToLoad == 0) { - for (int k = 0; k < FRAG_SIZE_M; k += (WARP_SIZE / FRAG_SIZE_N)) { - int sRow = lane_id / FRAG_SIZE_N + k; - int sThrColOff = lane_id % FRAG_SIZE_N; - int sCol = sThrColOff + (warp_id / kWarpsToLoad) * FRAG_SIZE_N; - - int gBlockColOff = blockIdx.x * fragments_per_block * FRAG_SIZE_N; - int gWarpColOff_0 = (warp_id / kWarpsToLoad * fragments_per_warp + j) * FRAG_SIZE_N; - int gCol = gBlockColOff + gWarpColOff_0 + sThrColOff; - if (gCol < num_experts) { - atomicAdd(local_offsets + gCol + 1, *(smem_ptr + sRow * FRAGS_PER_BLOCK * FRAG_SIZE_N + sCol)); - // atomicAdd(tokens_cnts_ptr + gCol, *(smem_ptr + sRow * FRAGS_PER_BLOCK * FRAG_SIZE_N + sCol)); - } - } - } - } - __syncthreads(); - - } // end of j - } // end of i - - if (threadIdx.x < num_experts) { - atomicAdd(tokens_cnts_ptr + threadIdx.x, *(local_offsets + threadIdx.x + 1)); - } - - __threadfence_system(); - grid.sync(); - - if (tid < num_experts) { - *(local_offsets + tid + 1) = *(tokens_cnts_ptr + tid); - } - __syncthreads(); -} - -SGLANG_FORCE_INLINE_DEVICE_FUNC void parallel_unaligned_local_cumsum( - const int& tid, int* tokens_cnts_ptr /*dest*/, int32_t* local_offsets /*dest*/, int32_t* local_offsets_buf, - const int32_t (*shared_counts)[EXPERTS_PER_WARP] /*src*/, const int& experts_per_warp, const int32_t& num_experts, - cg::grid_group& grid) { - int active_threads = CEILDIV(num_experts, kElementsPerThr); - - if (threadIdx.x == 0) { - local_offsets[0] = 0; - } - if (threadIdx.x < active_threads) { - for (int i = threadIdx.x * kElementsPerThr + 1; i < MIN((threadIdx.x + 1) * kElementsPerThr, num_experts) + 1; - ++i) { - int warp_idx = (i - 1) / experts_per_warp; - int expert_offset = (i - 1) % experts_per_warp; - - int expert_count = shared_counts[warp_idx][expert_offset]; - - int last_val = (i - 1) % kElementsPerThr == 0 ? 0 : local_offsets[i - 1]; - local_offsets[i] = last_val + expert_count; - } - - local_offsets_buf[threadIdx.x] = local_offsets[MIN((threadIdx.x + 1) * kElementsPerThr, num_experts)]; - } - __syncthreads(); - - if (threadIdx.x < active_threads && threadIdx.x > 0) { - int offset = 0; - for (int j = 0; j < threadIdx.x; ++j) { - offset += local_offsets_buf[j]; - } - - for (int i = threadIdx.x * kElementsPerThr + 1; i < MIN((threadIdx.x + 1) * kElementsPerThr, num_experts) + 1; - ++i) { - local_offsets[i] += offset; - } - } - __syncthreads(); - - if (tid < num_experts) { - *(tokens_cnts_ptr + tid) = 0; - } - if (threadIdx.x < num_experts) { - *(tokens_cnts_ptr + (blockIdx.x + 1) * num_experts + threadIdx.x) = *(local_offsets + threadIdx.x + 1); - *(local_offsets + threadIdx.x + 1) = 0; - } else if (threadIdx.x < MAX_NUM_EXPERTS) { - *(local_offsets + threadIdx.x + 1) = 0; - } - __threadfence_system(); - grid.sync(); } template __global__ void moe_align_block_size_kernel(const scalar_t* __restrict__ topk_ids, int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids, int32_t* __restrict__ total_tokens_post_pad, int32_t num_experts, - int32_t block_size, size_t numel, int32_t* __restrict__ tokens_cnts, - int32_t* __restrict__ cumsum, const int tokens_per_block, - const int tokens_per_thread, const int K) { - __shared__ int32_t smem[FRAG_SIZE_M * FRAG_SIZE_N * FRAGS_PER_BLOCK]; - int32_t(*shared_counts)[EXPERTS_PER_WARP] = (int32_t(*)[EXPERTS_PER_WARP]) & smem[0]; + int32_t block_size, size_t numel, int32_t* __restrict__ cumsum) { + __shared__ int32_t shared_counts[WARP_SIZE][8]; - __shared__ int32_t local_offsets[MAX_NUM_EXPERTS + 1]; - __shared__ int32_t local_offsets_buf[CEILDIV(MAX_NUM_EXPERTS, kElementsPerThr)]; - - const int tid = threadIdx.x + blockDim.x * blockIdx.x; const int warp_id = threadIdx.x / WARP_SIZE; - const int lane_id = threadIdx.x % WARP_SIZE; - const int experts_per_warp = EXPERTS_PER_WARP; + const int experts_per_warp = 8; + const int my_expert_start = warp_id * experts_per_warp; - int* tokens_cnts_ptr = &(tokens_cnts[0]); - int* smem_ptr = &(smem[0]); - - cg::grid_group grid = cg::this_grid(); - - if (threadIdx.x < FRAG_SIZE_M * FRAG_SIZE_N) { - for (int i = 0; i < FRAG_SIZE_M * FRAG_SIZE_N * FRAGS_PER_BLOCK; i += FRAG_SIZE_M * FRAG_SIZE_N) { - smem[threadIdx.x + i] = 0; + for (int i = 0; i < experts_per_warp; ++i) { + if (my_expert_start + i < num_experts) { + shared_counts[warp_id][i] = 0; } } + __syncthreads(); - const size_t start_idx = tokens_per_block * blockIdx.x + tokens_per_thread * threadIdx.x; - const size_t end_idx = start_idx + tokens_per_thread; + const size_t tokens_per_thread = CEILDIV(numel, blockDim.x); + const size_t start_idx = threadIdx.x * tokens_per_thread; - if (threadIdx.x * tokens_per_thread < tokens_per_block) { - for (int i = start_idx; i < MIN(numel, end_idx); ++i) { - int expert_id = topk_ids[i]; - int warp_idx = expert_id / experts_per_warp; - int expert_offset = expert_id % experts_per_warp; - atomicAdd(&shared_counts[warp_idx][expert_offset], 1); - } + for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { + int expert_id = topk_ids[i]; + int warp_idx = expert_id / experts_per_warp; + int expert_offset = expert_id % experts_per_warp; + atomicAdd(&shared_counts[warp_idx][expert_offset], 1); } + __syncthreads(); - parallel_unaligned_local_cumsum(tid, tokens_cnts_ptr /*dest*/, local_offsets, local_offsets_buf, shared_counts, - experts_per_warp, num_experts, grid); + if (threadIdx.x == 0) { + cumsum[0] = 0; + for (int i = 1; i <= num_experts; ++i) { + int expert_count = 0; + int warp_idx = (i - 1) / experts_per_warp; + int expert_offset = (i - 1) % experts_per_warp; + expert_count = shared_counts[warp_idx][expert_offset]; - reduce_unaligned_cumsum(tokens_cnts_ptr /*src_and_dest*/, smem_ptr, local_offsets, tid, lane_id, warp_id, num_experts, - grid); - - align_global_cumsum(local_offsets /*src_and_dest*/, local_offsets_buf, smem_ptr, tid, block_size, num_experts); - - store_global_cumsum(cumsum /*dest*/, total_tokens_post_pad /*dest*/, local_offsets /*src*/, tid, num_experts, grid); - - if (tid < num_experts) { - for (int i = local_offsets[tid]; i < local_offsets[tid + 1]; i += block_size) { - expert_ids[i / block_size] = tid; + cumsum[i] = cumsum[i - 1] + CEILDIV(expert_count, block_size) * block_size; } + *total_tokens_post_pad = cumsum[num_experts]; } + __syncthreads(); - if (threadIdx.x * tokens_per_thread < tokens_per_block) { - for (int i = start_idx; i < MIN(numel, end_idx); ++i) { - int32_t expert_id = topk_ids[i]; - int32_t rank_post_pad = atomicAdd(&cumsum[expert_id], 1); - sorted_token_ids[rank_post_pad] = i; + if (threadIdx.x < num_experts) { + for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; i += block_size) { + expert_ids[i / block_size] = threadIdx.x; } } } @@ -302,29 +95,22 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t b torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad, torch::Tensor token_cnts_buffer, torch::Tensor cumsum_buffer) { const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + TORCH_CHECK(num_experts == 256, "moe_align_block_size kernel only support deepseek v3 now."); DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { - auto kernel = moe_align_block_size_kernel; + auto align_kernel = moe_align_block_size_kernel; + align_kernel<<<1, 1024, 0, stream>>>(topk_ids.data_ptr(), sorted_token_ids.data_ptr(), + experts_ids.data_ptr(), num_tokens_post_pad.data_ptr(), + num_experts, block_size, topk_ids.numel(), cumsum_buffer.data_ptr()); const int block_threads = 256; + const int num_blocks = (topk_ids.numel() + block_threads - 1) / block_threads; + const int max_blocks = 65535; + const int actual_blocks = std::min(num_blocks, max_blocks); - const int num_blocks = MIN(CEILDIV(topk_ids.sizes()[0], block_threads), num_experts); - - scalar_t* topk_ids_ptr = topk_ids.data_ptr(); - int32_t* sorted_token_ids_ptr = sorted_token_ids.data_ptr(); - int32_t* experts_ids_ptr = experts_ids.data_ptr(); - int32_t* num_tokens_post_pad_ptr = num_tokens_post_pad.data_ptr(); - size_t num_tokens = topk_ids.numel(); - int32_t* token_cnts_buffer_ptr = token_cnts_buffer.data_ptr(); - int32_t* cumsum_buffer_ptr = cumsum_buffer.data_ptr(); - int tokens_per_block = CEILDIV(topk_ids.sizes()[0], num_blocks) * topk_ids.sizes()[1]; - int tokens_per_thread = CEILDIV(tokens_per_block, block_threads); - int K = topk_ids.sizes()[1]; - - void* kernelArgs[] = {&topk_ids_ptr, &sorted_token_ids_ptr, &experts_ids_ptr, &num_tokens_post_pad_ptr, - &num_experts, &block_size, &num_tokens, &token_cnts_buffer_ptr, - &cumsum_buffer_ptr, &tokens_per_block, &tokens_per_thread, &K}; - - cudaLaunchCooperativeKernel((void*)kernel, num_blocks, block_threads, kernelArgs); + auto sort_kernel = count_and_sort_expert_tokens_kernel; + sort_kernel<<>>(topk_ids.data_ptr(), + sorted_token_ids.data_ptr(), + cumsum_buffer.data_ptr(), topk_ids.numel()); }); } diff --git a/sgl-kernel/src/sgl-kernel/include/utils.h b/sgl-kernel/src/sgl-kernel/include/utils.h index ab7806eba..a342dee10 100644 --- a/sgl-kernel/src/sgl-kernel/include/utils.h +++ b/sgl-kernel/src/sgl-kernel/include/utils.h @@ -49,17 +49,6 @@ struct cuda_error : public std::runtime_error { } \ } while (0) -#define checkCudaErrors(val) check((val), #val, __FILE__, __LINE__) -template -void check(T result, char const* const func, const char* const file, int const line) { - if (result) { - fprintf(stderr, "CUDA error at %s:%d code=%d(%s) \"%s\" \n", file, line, static_cast(result), - cudaGetErrorString(result), func); - cudaDeviceReset(); - exit(EXIT_FAILURE); - } -} - #define CHECK_IS_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") #define CHECK_IS_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") #define CHECK_CUDA_INPUT(x) \ @@ -106,22 +95,3 @@ inline int getSMVersion() { AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__)) #define CEILDIV(x, y) (((x) + (y)-1) / (y)) -#define MIN(x, y) ((x) < (y) ? (x) : (y)) - -#ifndef USE_ROCM -#define WARP_SIZE 32 -#else -#define WARP_SIZE warpSize // 64 -#endif - -#if defined(__HIP_PLATFORM_AMD__) - -#include -#include - -static __inline__ __host__ __device__ hipError_t cudaLaunchCooperativeKernel(const void* f, dim3 gridDim, - dim3 blockDimX, void** kernelParams) { - return hipLaunchCooperativeKernel(f, gridDim, blockDimX, kernelParams, 0, hipStreamDefault); -} - -#endif diff --git a/sgl-kernel/tests/test_moe_align.py b/sgl-kernel/tests/test_moe_align.py index b1790851b..81d05ffa1 100644 --- a/sgl-kernel/tests/test_moe_align.py +++ b/sgl-kernel/tests/test_moe_align.py @@ -171,12 +171,12 @@ def test_moe_align_block_size_compare_implementations(block_size, num_tokens, to num_tokens_post_pad_cuda = torch.empty( (1), dtype=torch.int32, device=topk_ids.device ) - token_cnts_buffer = torch.zeros( + token_cnts_buffer = torch.empty( (num_experts + 1) * num_experts, dtype=torch.int32, device=topk_ids.device, ) - cumsum_buffer = torch.zeros( + cumsum_buffer = torch.empty( num_experts + 1, dtype=torch.int32, device=topk_ids.device )