diff --git a/benchmark/kernels/fused_moe_triton/benchmark_deepseekv3_moe_align_blocks.py b/benchmark/kernels/fused_moe_triton/benchmark_deepseekv3_moe_align_blocks.py index 1d9504d0a..ff8be14a3 100644 --- a/benchmark/kernels/fused_moe_triton/benchmark_deepseekv3_moe_align_blocks.py +++ b/benchmark/kernels/fused_moe_triton/benchmark_deepseekv3_moe_align_blocks.py @@ -99,13 +99,12 @@ def moe_align_block_size_triton( sorted_token_ids: torch.Tensor, expert_ids: torch.Tensor, num_tokens_post_pad: torch.Tensor, + tokens_cnts: torch.Tensor, + cumsum: torch.Tensor, ) -> None: numel = topk_ids.numel() grid = (num_experts,) - tokens_cnts = torch.zeros( - (num_experts + 1, num_experts), dtype=torch.int32, device=topk_ids.device - ) - cumsum = torch.zeros((num_experts + 1,), dtype=torch.int32, device=topk_ids.device) + tokens_per_thread = ceil_div(numel, num_experts) moe_align_block_size_stage1[grid]( @@ -139,11 +138,18 @@ def moe_align_block_size_triton( ) -def calculate_diff(batch_size, seq_len): - num_experts = 256 +def calculate_diff(batch_size, seq_len, num_experts): + num_experts = num_experts block_size = 128 topk = 8 + assert batch_size >= 1 + assert seq_len >= 1 + assert num_experts >= 4 + + if topk > num_experts: + topk = num_experts + topk_ids = torch.stack( [ torch.randperm(num_experts, dtype=torch.int32, device="cuda")[:topk] @@ -175,6 +181,13 @@ def calculate_diff(batch_size, seq_len): expert_ids_triton = torch.zeros_like(expert_ids_cuda) num_tokens_post_pad_triton = torch.empty_like(num_tokens_post_pad_cuda) + token_cnts_buffer_triton = torch.zeros( + (num_experts + 1, num_experts), dtype=torch.int32, device=topk_ids.device + ) + cumsum_buffer_triton = torch.zeros( + (num_experts + 1,), dtype=torch.int32, device=topk_ids.device + ) + # compare the performance of cuda and triton implementation moe_align_block_size( topk_ids, @@ -193,14 +206,27 @@ def calculate_diff(batch_size, seq_len): sorted_ids_triton, expert_ids_triton, num_tokens_post_pad_triton, + token_cnts_buffer_triton, + cumsum_buffer_triton, ) - if torch.allclose(expert_ids_cuda, expert_ids_triton) and torch.allclose( - num_tokens_post_pad_cuda, num_tokens_post_pad_triton + sorted_ids_cuda_snapshot = sorted_ids_cuda[: cumsum_buffer[1]].sort().values + sorted_ids_triton_snapshot = sorted_ids_triton[: cumsum_buffer[1]].sort().values + + if ( + torch.allclose(expert_ids_cuda, expert_ids_triton) + and torch.allclose(num_tokens_post_pad_cuda, num_tokens_post_pad_triton) + and torch.allclose(sorted_ids_cuda_snapshot, sorted_ids_triton_snapshot) ): - print("✅ CUDA and Triton implementations match") + print( + "✅ CUDA and Triton implementations match : num_tokens={}, num_experts={}".format( + batch_size * seq_len, num_experts + ) + ) else: print("❌ CUDA and Triton implementations do not match") + print("CUDA sorted ids:", sorted_ids_cuda_snapshot) + print("Triton sorted ids:", sorted_ids_triton_snapshot) print("CUDA expert_ids:", expert_ids_cuda) print("Triton expert_ids:", expert_ids_triton) print("CUDA num_tokens_post_pad:", num_tokens_post_pad_cuda) @@ -256,7 +282,7 @@ def benchmark(batch_size, seq_len, provider): ) sorted_ids.fill_(topk_ids.numel()) max_num_m_blocks = max_num_tokens_padded // block_size - expert_ids = torch.empty( + expert_ids = torch.zeros( (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device ) num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device) @@ -267,34 +293,37 @@ def benchmark(batch_size, seq_len, provider): num_experts + 1, dtype=torch.int32, device=topk_ids.device ) - quantiles = [0.5, 0.2, 0.8] - if provider == "cuda": - ms, min_ms, max_ms = triton.testing.do_bench( - lambda: moe_align_block_size( - topk_ids, - num_experts, - block_size, - sorted_ids.clone(), - expert_ids.clone(), - num_tokens_post_pad.clone(), - token_cnts_buffer, - cumsum_buffer, - ), - quantiles=quantiles, - ) - else: - ms, min_ms, max_ms = triton.testing.do_bench( - lambda: moe_align_block_size_triton( - topk_ids, - num_experts, - block_size, - sorted_ids.clone(), - expert_ids.clone(), - num_tokens_post_pad.clone(), - ), - quantiles=quantiles, + # Warm up + api_func = ( + moe_align_block_size if provider == "cuda" else moe_align_block_size_triton + ) + for _ in range(10): + api_func( + topk_ids, + num_experts, + block_size, + sorted_ids, + expert_ids, + num_tokens_post_pad, + token_cnts_buffer.clone(), + cumsum_buffer.clone(), ) + torch.cuda.synchronize() + quantiles = [0.5, 0.2, 0.8] + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: api_func( + topk_ids, + num_experts, + block_size, + sorted_ids, + expert_ids, + num_tokens_post_pad, + token_cnts_buffer.clone(), + cumsum_buffer.clone(), + ), + quantiles=quantiles, + ) return 1000 * ms, 1000 * max_ms, 1000 * min_ms @@ -306,8 +335,22 @@ if __name__ == "__main__": default="./configs/benchmark_ops/moe_align_blocks/", help="Path to save moe align benchmark results", ) + parser.add_argument( + "--verify", + action="store_true", + help="verify kernel", + ) + args = parser.parse_args() - calculate_diff(batch_size=4, seq_len=1024) + if args.verify: + num_experts_range = [2**i for i in range(3, 9)] - benchmark.run(print_data=True) + configs = list( + itertools.product(batch_size_range, seq_length_range, num_experts_range) + ) + + for bs, seq, num_experts in configs: + calculate_diff(batch_size=bs, seq_len=seq, num_experts=num_experts) + else: + benchmark.run(print_data=True, save_path=args.save_path) diff --git a/sgl-kernel/pyproject.toml b/sgl-kernel/pyproject.toml index b9f79fce6..953db4f48 100644 --- a/sgl-kernel/pyproject.toml +++ b/sgl-kernel/pyproject.toml @@ -1,5 +1,5 @@ [build-system] -requires = ["setuptools>=61.0", "wheel", "torch"] +requires = ["setuptools>=61.0", "wheel", "torch<=2.5.1"] build-backend = "setuptools.build_meta" [project] diff --git a/sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu index 473aae6f5..592c5bff9 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu @@ -16,77 +16,284 @@ limitations under the License. #include #include #include +#include #include #include #include "utils.h" -#define WARP_SIZE 32 +#define MAX_NUM_EXPERTS 256 +#define EXPERTS_PER_WARP ((MAX_NUM_EXPERTS) / (WARP_SIZE)) -template -__global__ void count_and_sort_expert_tokens_kernel(const scalar_t* __restrict__ topk_ids, - int32_t* __restrict__ sorted_token_ids, - int32_t* __restrict__ cumsum_buffer, size_t numel) { - const size_t tid = blockIdx.x * blockDim.x + threadIdx.x; - const size_t stride = blockDim.x * gridDim.x; +#define FRAGS_PER_BLOCK 4 - for (size_t i = tid; i < numel; i += stride) { - int32_t expert_id = topk_ids[i]; - int32_t rank_post_pad = atomicAdd(&cumsum_buffer[expert_id], 1); - sorted_token_ids[rank_post_pad] = i; +#define FRAG_SIZE_M 16 +#define FRAG_SIZE_N 16 + +#ifndef USE_ROCM +#define kWarpsToLoad 2 +#else +#define kWarpsToLoad 1 +#endif + +#define kElementsPerAccess 4 +#define kElementsPerThr 16 + +#define SGLANG_FORCE_INLINE_DEVICE_FUNC static __forceinline__ __attribute__((always_inline)) __device__ + +namespace cg = cooperative_groups; + +SGLANG_FORCE_INLINE_DEVICE_FUNC void store_global_cumsum(int* cumsum /*dest*/, int* total_tokens_post_pad /*dest*/, + const int32_t* local_offsets, const int& tid, + const int& num_experts, cg::grid_group& grid) { + int active_threads = CEILDIV(num_experts + 1, kElementsPerThr); + if (tid < active_threads - 1) { + for (int i = tid * kElementsPerThr; i < (tid + 1) * kElementsPerThr; i += kElementsPerAccess) { + *(int4*)(cumsum + i) = *(int4*)(local_offsets + i); + } } + + if (tid == active_threads - 1) { +#pragma unroll + for (int i = tid * kElementsPerThr; i < num_experts + 1; i++) { + *(cumsum + i) = *(local_offsets + i); + } + } + if (tid == active_threads) { + *total_tokens_post_pad = local_offsets[num_experts]; + } + __threadfence_system(); + grid.sync(); +} + +SGLANG_FORCE_INLINE_DEVICE_FUNC void align_global_cumsum(int32_t* local_offsets /*src_and_dest*/, + int32_t* local_offsets_buf, int* smem_ptr, const int tid, + const int32_t& block_size, const int32_t& num_experts) { + int active_threads = CEILDIV(num_experts, kElementsPerThr); + int start = tid * kElementsPerThr + 1; + int end = MIN((tid + 1) * kElementsPerThr, num_experts) + 1; + if (tid == 0) { + smem_ptr[0] = 0; + } + if (tid < active_threads) { + for (int i = start; i < end; ++i) { + smem_ptr[i] = local_offsets[i] - local_offsets[i - 1]; + } + } + __syncthreads(); + + if (tid < active_threads) { + for (int i = start; i < end; ++i) { + int last_val = (i - 1) % kElementsPerThr == 0 ? 0 : local_offsets[i - 1]; + local_offsets[i] = last_val + CEILDIV(smem_ptr[i], block_size) * block_size; + } + + local_offsets_buf[tid] = local_offsets[end - 1]; + } + __syncthreads(); + + if (tid < active_threads && tid > 0) { + int offset = 0; + for (int j = 0; j < tid; ++j) { + offset += local_offsets_buf[j]; + } + + for (int i = start; i < end; ++i) { + local_offsets[i] += offset; + } + } + __syncthreads(); +} + +SGLANG_FORCE_INLINE_DEVICE_FUNC void reduce_unaligned_cumsum(int* tokens_cnts_ptr /*src_and_dest*/, int* smem_ptr, + int32_t* local_offsets, const int& tid, const int& lane_id, + const int& warp_id, const int32_t& num_experts, + cg::grid_group& grid) { + int total_fragments = CEILDIV(num_experts, FRAG_SIZE_N); + int fragments_per_block = CEILDIV(total_fragments, gridDim.x); + int fragments_per_warp = CEILDIV(fragments_per_block, FRAGS_PER_BLOCK); + + for (int i = 0; i < gridDim.x; i += FRAG_SIZE_M) { + for (int j = 0; j < fragments_per_warp; j++) { + if (warp_id * fragments_per_warp < kWarpsToLoad * fragments_per_block) { + const int kNumThrPerRow = WARP_SIZE / FRAG_SIZE_N; + int sRow = lane_id / kNumThrPerRow; + + int sWarpColStride = kNumThrPerRow * kElementsPerAccess; + int sWarpColOff = warp_id * sWarpColStride; + int sThrColOff = lane_id % kNumThrPerRow * kElementsPerAccess; + + int sCol = sThrColOff + sWarpColOff; + + int gRow = i + sRow; + + int gBlockColOff = blockIdx.x * fragments_per_block * FRAG_SIZE_N; + int gWarpColOff_0 = (warp_id / kWarpsToLoad * fragments_per_warp + j) * FRAG_SIZE_N; + int gWarpColOff_1 = warp_id % kWarpsToLoad * sWarpColStride; + + int gCol = gBlockColOff + gWarpColOff_0 + gWarpColOff_1 + sThrColOff; + + if (gRow < num_experts && gCol < num_experts) { + int4* tokens_cnts_4i_ptr = (int4*)(tokens_cnts_ptr + (gRow + 1) * num_experts + gCol); + int4* smem_4i_ptr = (int4*)(smem_ptr + sRow * FRAGS_PER_BLOCK * FRAG_SIZE_N + sCol); + + *smem_4i_ptr = *tokens_cnts_4i_ptr; + } + } + __syncthreads(); + + if (warp_id * fragments_per_warp < kWarpsToLoad * fragments_per_block) { + if (warp_id % kWarpsToLoad == 0) { + for (int k = 0; k < FRAG_SIZE_M; k += (WARP_SIZE / FRAG_SIZE_N)) { + int sRow = lane_id / FRAG_SIZE_N + k; + int sThrColOff = lane_id % FRAG_SIZE_N; + int sCol = sThrColOff + (warp_id / kWarpsToLoad) * FRAG_SIZE_N; + + int gBlockColOff = blockIdx.x * fragments_per_block * FRAG_SIZE_N; + int gWarpColOff_0 = (warp_id / kWarpsToLoad * fragments_per_warp + j) * FRAG_SIZE_N; + int gCol = gBlockColOff + gWarpColOff_0 + sThrColOff; + if (gCol < num_experts) { + atomicAdd(local_offsets + gCol + 1, *(smem_ptr + sRow * FRAGS_PER_BLOCK * FRAG_SIZE_N + sCol)); + // atomicAdd(tokens_cnts_ptr + gCol, *(smem_ptr + sRow * FRAGS_PER_BLOCK * FRAG_SIZE_N + sCol)); + } + } + } + } + __syncthreads(); + + } // end of j + } // end of i + + if (threadIdx.x < num_experts) { + atomicAdd(tokens_cnts_ptr + threadIdx.x, *(local_offsets + threadIdx.x + 1)); + } + + __threadfence_system(); + grid.sync(); + + if (tid < num_experts) { + *(local_offsets + tid + 1) = *(tokens_cnts_ptr + tid); + } + __syncthreads(); +} + +SGLANG_FORCE_INLINE_DEVICE_FUNC void parallel_unaligned_local_cumsum( + const int& tid, int* tokens_cnts_ptr /*dest*/, int32_t* local_offsets /*dest*/, int32_t* local_offsets_buf, + const int32_t (*shared_counts)[EXPERTS_PER_WARP] /*src*/, const int& experts_per_warp, const int32_t& num_experts, + cg::grid_group& grid) { + int active_threads = CEILDIV(num_experts, kElementsPerThr); + + if (threadIdx.x == 0) { + local_offsets[0] = 0; + } + if (threadIdx.x < active_threads) { + for (int i = threadIdx.x * kElementsPerThr + 1; i < MIN((threadIdx.x + 1) * kElementsPerThr, num_experts) + 1; + ++i) { + int warp_idx = (i - 1) / experts_per_warp; + int expert_offset = (i - 1) % experts_per_warp; + + int expert_count = shared_counts[warp_idx][expert_offset]; + + int last_val = (i - 1) % kElementsPerThr == 0 ? 0 : local_offsets[i - 1]; + local_offsets[i] = last_val + expert_count; + } + + local_offsets_buf[threadIdx.x] = local_offsets[MIN((threadIdx.x + 1) * kElementsPerThr, num_experts)]; + } + __syncthreads(); + + if (threadIdx.x < active_threads && threadIdx.x > 0) { + int offset = 0; + for (int j = 0; j < threadIdx.x; ++j) { + offset += local_offsets_buf[j]; + } + + for (int i = threadIdx.x * kElementsPerThr + 1; i < MIN((threadIdx.x + 1) * kElementsPerThr, num_experts) + 1; + ++i) { + local_offsets[i] += offset; + } + } + __syncthreads(); + + if (tid < num_experts) { + *(tokens_cnts_ptr + tid) = 0; + } + if (threadIdx.x < num_experts) { + *(tokens_cnts_ptr + (blockIdx.x + 1) * num_experts + threadIdx.x) = *(local_offsets + threadIdx.x + 1); + *(local_offsets + threadIdx.x + 1) = 0; + } else if (threadIdx.x < MAX_NUM_EXPERTS) { + *(local_offsets + threadIdx.x + 1) = 0; + } + __threadfence_system(); + grid.sync(); } template __global__ void moe_align_block_size_kernel(const scalar_t* __restrict__ topk_ids, int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids, int32_t* __restrict__ total_tokens_post_pad, int32_t num_experts, - int32_t block_size, size_t numel, int32_t* __restrict__ cumsum) { - __shared__ int32_t shared_counts[WARP_SIZE][8]; + int32_t block_size, size_t numel, int32_t* __restrict__ tokens_cnts, + int32_t* __restrict__ cumsum, const int tokens_per_block, + const int tokens_per_thread, const int K) { + __shared__ int32_t smem[FRAG_SIZE_M * FRAG_SIZE_N * FRAGS_PER_BLOCK]; + int32_t(*shared_counts)[EXPERTS_PER_WARP] = (int32_t(*)[EXPERTS_PER_WARP]) & smem[0]; + __shared__ int32_t local_offsets[MAX_NUM_EXPERTS + 1]; + __shared__ int32_t local_offsets_buf[CEILDIV(MAX_NUM_EXPERTS, kElementsPerThr)]; + + const int tid = threadIdx.x + blockDim.x * blockIdx.x; const int warp_id = threadIdx.x / WARP_SIZE; - const int experts_per_warp = 8; - const int my_expert_start = warp_id * experts_per_warp; + const int lane_id = threadIdx.x % WARP_SIZE; + const int experts_per_warp = EXPERTS_PER_WARP; - for (int i = 0; i < experts_per_warp; ++i) { - if (my_expert_start + i < num_experts) { - shared_counts[warp_id][i] = 0; + int* tokens_cnts_ptr = &(tokens_cnts[0]); + int* smem_ptr = &(smem[0]); + + cg::grid_group grid = cg::this_grid(); + + if (threadIdx.x < FRAG_SIZE_M * FRAG_SIZE_N) { + for (int i = 0; i < FRAG_SIZE_M * FRAG_SIZE_N * FRAGS_PER_BLOCK; i += FRAG_SIZE_M * FRAG_SIZE_N) { + smem[threadIdx.x + i] = 0; } } - __syncthreads(); - const size_t tokens_per_thread = CEILDIV(numel, blockDim.x); - const size_t start_idx = threadIdx.x * tokens_per_thread; + const size_t start_idx = tokens_per_block * blockIdx.x + tokens_per_thread * threadIdx.x; + const size_t end_idx = start_idx + tokens_per_thread; - for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { - int expert_id = topk_ids[i]; - int warp_idx = expert_id / experts_per_warp; - int expert_offset = expert_id % experts_per_warp; - atomicAdd(&shared_counts[warp_idx][expert_offset], 1); - } - - __syncthreads(); - - if (threadIdx.x == 0) { - cumsum[0] = 0; - for (int i = 1; i <= num_experts; ++i) { - int expert_count = 0; - int warp_idx = (i - 1) / experts_per_warp; - int expert_offset = (i - 1) % experts_per_warp; - expert_count = shared_counts[warp_idx][expert_offset]; - - cumsum[i] = cumsum[i - 1] + CEILDIV(expert_count, block_size) * block_size; + if (threadIdx.x * tokens_per_thread < tokens_per_block) { + for (int i = start_idx; i < MIN(numel, end_idx); ++i) { + int expert_id = topk_ids[i]; + int warp_idx = expert_id / experts_per_warp; + int expert_offset = expert_id % experts_per_warp; + atomicAdd(&shared_counts[warp_idx][expert_offset], 1); } - *total_tokens_post_pad = cumsum[num_experts]; } - __syncthreads(); - if (threadIdx.x < num_experts) { - for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; i += block_size) { - expert_ids[i / block_size] = threadIdx.x; + parallel_unaligned_local_cumsum(tid, tokens_cnts_ptr /*dest*/, local_offsets, local_offsets_buf, shared_counts, + experts_per_warp, num_experts, grid); + + reduce_unaligned_cumsum(tokens_cnts_ptr /*src_and_dest*/, smem_ptr, local_offsets, tid, lane_id, warp_id, num_experts, + grid); + + align_global_cumsum(local_offsets /*src_and_dest*/, local_offsets_buf, smem_ptr, tid, block_size, num_experts); + + store_global_cumsum(cumsum /*dest*/, total_tokens_post_pad /*dest*/, local_offsets /*src*/, tid, num_experts, grid); + + if (tid < num_experts) { + for (int i = local_offsets[tid]; i < local_offsets[tid + 1]; i += block_size) { + expert_ids[i / block_size] = tid; + } + } + __syncthreads(); + + if (threadIdx.x * tokens_per_thread < tokens_per_block) { + for (int i = start_idx; i < MIN(numel, end_idx); ++i) { + int32_t expert_id = topk_ids[i]; + int32_t rank_post_pad = atomicAdd(&cumsum[expert_id], 1); + sorted_token_ids[rank_post_pad] = i; } } } @@ -95,22 +302,29 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t b torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad, torch::Tensor token_cnts_buffer, torch::Tensor cumsum_buffer) { const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - TORCH_CHECK(num_experts == 256, "moe_align_block_size kernel only support deepseek v3 now."); DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { - auto align_kernel = moe_align_block_size_kernel; - align_kernel<<<1, 1024, 0, stream>>>(topk_ids.data_ptr(), sorted_token_ids.data_ptr(), - experts_ids.data_ptr(), num_tokens_post_pad.data_ptr(), - num_experts, block_size, topk_ids.numel(), cumsum_buffer.data_ptr()); + auto kernel = moe_align_block_size_kernel; const int block_threads = 256; - const int num_blocks = (topk_ids.numel() + block_threads - 1) / block_threads; - const int max_blocks = 65535; - const int actual_blocks = std::min(num_blocks, max_blocks); - auto sort_kernel = count_and_sort_expert_tokens_kernel; - sort_kernel<<>>(topk_ids.data_ptr(), - sorted_token_ids.data_ptr(), - cumsum_buffer.data_ptr(), topk_ids.numel()); + const int num_blocks = MIN(CEILDIV(topk_ids.sizes()[0], block_threads), num_experts); + + scalar_t* topk_ids_ptr = topk_ids.data_ptr(); + int32_t* sorted_token_ids_ptr = sorted_token_ids.data_ptr(); + int32_t* experts_ids_ptr = experts_ids.data_ptr(); + int32_t* num_tokens_post_pad_ptr = num_tokens_post_pad.data_ptr(); + size_t num_tokens = topk_ids.numel(); + int32_t* token_cnts_buffer_ptr = token_cnts_buffer.data_ptr(); + int32_t* cumsum_buffer_ptr = cumsum_buffer.data_ptr(); + int tokens_per_block = CEILDIV(topk_ids.sizes()[0], num_blocks) * topk_ids.sizes()[1]; + int tokens_per_thread = CEILDIV(tokens_per_block, block_threads); + int K = topk_ids.sizes()[1]; + + void* kernelArgs[] = {&topk_ids_ptr, &sorted_token_ids_ptr, &experts_ids_ptr, &num_tokens_post_pad_ptr, + &num_experts, &block_size, &num_tokens, &token_cnts_buffer_ptr, + &cumsum_buffer_ptr, &tokens_per_block, &tokens_per_thread, &K}; + + cudaLaunchCooperativeKernel((void*)kernel, num_blocks, block_threads, kernelArgs); }); } diff --git a/sgl-kernel/src/sgl-kernel/include/utils.h b/sgl-kernel/src/sgl-kernel/include/utils.h index a342dee10..ab7806eba 100644 --- a/sgl-kernel/src/sgl-kernel/include/utils.h +++ b/sgl-kernel/src/sgl-kernel/include/utils.h @@ -49,6 +49,17 @@ struct cuda_error : public std::runtime_error { } \ } while (0) +#define checkCudaErrors(val) check((val), #val, __FILE__, __LINE__) +template +void check(T result, char const* const func, const char* const file, int const line) { + if (result) { + fprintf(stderr, "CUDA error at %s:%d code=%d(%s) \"%s\" \n", file, line, static_cast(result), + cudaGetErrorString(result), func); + cudaDeviceReset(); + exit(EXIT_FAILURE); + } +} + #define CHECK_IS_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") #define CHECK_IS_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") #define CHECK_CUDA_INPUT(x) \ @@ -95,3 +106,22 @@ inline int getSMVersion() { AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__)) #define CEILDIV(x, y) (((x) + (y)-1) / (y)) +#define MIN(x, y) ((x) < (y) ? (x) : (y)) + +#ifndef USE_ROCM +#define WARP_SIZE 32 +#else +#define WARP_SIZE warpSize // 64 +#endif + +#if defined(__HIP_PLATFORM_AMD__) + +#include +#include + +static __inline__ __host__ __device__ hipError_t cudaLaunchCooperativeKernel(const void* f, dim3 gridDim, + dim3 blockDimX, void** kernelParams) { + return hipLaunchCooperativeKernel(f, gridDim, blockDimX, kernelParams, 0, hipStreamDefault); +} + +#endif diff --git a/sgl-kernel/tests/test_moe_align.py b/sgl-kernel/tests/test_moe_align.py index 81d05ffa1..b1790851b 100644 --- a/sgl-kernel/tests/test_moe_align.py +++ b/sgl-kernel/tests/test_moe_align.py @@ -171,12 +171,12 @@ def test_moe_align_block_size_compare_implementations(block_size, num_tokens, to num_tokens_post_pad_cuda = torch.empty( (1), dtype=torch.int32, device=topk_ids.device ) - token_cnts_buffer = torch.empty( + token_cnts_buffer = torch.zeros( (num_experts + 1) * num_experts, dtype=torch.int32, device=topk_ids.device, ) - cumsum_buffer = torch.empty( + cumsum_buffer = torch.zeros( num_experts + 1, dtype=torch.int32, device=topk_ids.device )