From 8b5f83ed3b7d2a49ad5c5cd5aa61c5d502f47dbc Mon Sep 17 00:00:00 2001
From: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>
Date: Sat, 7 Jun 2025 17:47:36 +0800
Subject: [PATCH] reduce torch.zeros overhead in moe align block size kernel (#6369)

---
Reviewer notes (free-form text: this region between the "---" marker and
the diffstat is not part of the commit message and is not part of the diff
that gets applied).

* fused_moe.py adds init_sorted_ids_and_cumsum_buffer(). It allocates
  sorted_ids (max_num_tokens_padded int32 entries) and cumsum_buffer
  (num_experts + 1 int32 entries) with torch.empty and initialises both in
  one launch of init_sorted_ids_and_cumsum_buffer_kernel:
    - programs with pid < cdiv(max_num_tokens_padded, BLOCK_SIZE) each
      store topk_ids_numel into one BLOCK_SIZE-wide (1024) slice of
      sorted_ids, masked to max_num_tokens_padded;
    - the single extra program (pid == sorted_ids_blocks, which is why the
      grid is sorted_ids_blocks + 1) stores zeros into cumsum_buffer.
  ALIGNED_NUM_EXPERTS_P1 is next_power_of_2(num_experts + 1), presumably
  because tl.arange needs a power-of-two range (TODO confirm against the
  Triton docs); mask_e limits the store to the num_experts + 1 real
  entries.
* moe_align_block_size() now takes both buffers from that helper, so the
  separate sorted_ids.fill_(topk_ids.numel()) call and the later
  cumsum_buffer = torch.empty(...) allocation are removed.
* moe_align_kernel.cu no longer calls cumsum_buffer.zero_() before
  launching align_kernel. NOTE(review): this appears to move the
  "cumsum_buffer must already be zeroed" requirement onto every caller of
  the op. Only the fused_moe.py call site is updated in this patch, so any
  other callers should be checked -- TODO confirm.

 .../layers/moe/fused_moe_triton/fused_moe.py  | 64 +++++++++++++++++--
 sgl-kernel/csrc/moe/moe_align_kernel.cu       |  2 -
 2 files changed, 58 insertions(+), 8 deletions(-)

diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
index df4a490e4..935de6e57 100644
--- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
@@ -30,6 +30,7 @@ from sglang.srt.utils import (
     is_cuda,
     is_hip,
     log_info_on_rank0,
+    next_power_of_2,
 )
 
 _is_hip = is_hip()
@@ -650,6 +651,61 @@ def moe_align_block_size_triton(
     )
 
 
+@triton.jit
+def init_sorted_ids_and_cumsum_buffer_kernel(
+    sorted_ids_ptr,
+    cumsum_buffer_ptr,
+    max_num_tokens_padded,
+    topk_ids_numel,
+    num_experts: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+    ALIGNED_NUM_EXPERTS_P1: tl.constexpr,
+):
+    pid = tl.program_id(0)
+    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+
+    sorted_ids_blocks = tl.cdiv(max_num_tokens_padded, BLOCK_SIZE)
+
+    if pid < sorted_ids_blocks:
+        mask = offsets < max_num_tokens_padded
+        tl.store(
+            sorted_ids_ptr + offsets,
+            tl.full((BLOCK_SIZE,), topk_ids_numel, dtype=tl.int32),
+            mask=mask,
+        )
+    elif pid == sorted_ids_blocks:
+        offset_e = tl.arange(0, ALIGNED_NUM_EXPERTS_P1)
+        mask_e = offset_e < num_experts + 1
+        tl.store(
+            cumsum_buffer_ptr + offset_e,
+            tl.zeros((ALIGNED_NUM_EXPERTS_P1,), dtype=tl.int32),
+            mask=mask_e,
+        )
+
+
+def init_sorted_ids_and_cumsum_buffer(
+    max_num_tokens_padded: int, topk_ids_numel: int, num_experts: int, device="cuda"
+):
+    sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32, device=device)
+    cumsum_buffer = torch.empty((num_experts + 1,), dtype=torch.int32, device=device)
+
+    BLOCK_SIZE = 1024
+    sorted_ids_blocks = triton.cdiv(max_num_tokens_padded, BLOCK_SIZE)
+    grid = (sorted_ids_blocks + 1,)
+
+    init_sorted_ids_and_cumsum_buffer_kernel[grid](
+        sorted_ids,
+        cumsum_buffer,
+        max_num_tokens_padded,
+        topk_ids_numel,
+        num_experts,
+        BLOCK_SIZE,
+        next_power_of_2(num_experts + 1),
+    )
+
+    return sorted_ids, cumsum_buffer
+
+
 def moe_align_block_size(
     topk_ids: torch.Tensor, block_size: int, num_experts: int
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -691,10 +747,9 @@ def moe_align_block_size(
         by block_size for proper block matrix operations.
     """
     max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
-    sorted_ids = torch.empty(
-        (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
+    sorted_ids, cumsum_buffer = init_sorted_ids_and_cumsum_buffer(
+        max_num_tokens_padded, topk_ids.numel(), num_experts, topk_ids.device
     )
-    sorted_ids.fill_(topk_ids.numel())
     max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
     expert_ids = torch.empty(
         (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
@@ -715,9 +770,6 @@ def moe_align_block_size(
             dtype=torch.int32,
             device=topk_ids.device,
         )
-        cumsum_buffer = torch.empty(
-            num_experts + 1, dtype=torch.int32, device=topk_ids.device
-        )
 
         sgl_moe_align_block_size(
             topk_ids,
diff --git a/sgl-kernel/csrc/moe/moe_align_kernel.cu b/sgl-kernel/csrc/moe/moe_align_kernel.cu
index d44eff5c1..e3abb8849 100644
--- a/sgl-kernel/csrc/moe/moe_align_kernel.cu
+++ b/sgl-kernel/csrc/moe/moe_align_kernel.cu
@@ -197,8 +197,6 @@ void moe_align_block_size(
       size_t num_warps = CEILDIV(padded_num_experts, experts_per_warp);
       size_t shared_mem_size = num_warps * experts_per_warp * sizeof(int32_t);
 
-      cumsum_buffer.zero_();
-
       align_kernel<<<1, threads, shared_mem_size, stream>>>(
           topk_ids.data_ptr(),
           sorted_token_ids.data_ptr(),