From cb3918a09127a1da8cc0976f86e7425285a1dca6 Mon Sep 17 00:00:00 2001
From: Yuan Luo
Date: Sun, 7 Sep 2025 09:16:18 +0800
Subject: [PATCH] Optimize moe_sum_reduce_kernel (#9477)

Co-authored-by: luoyuan.luo
Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>
---
 .../fused_moe_triton/benchmark_sum_scale.py   | 45 ++++++++++---------
 .../fused_moe_triton_kernels.py               | 43 +++++++++---------
 2 files changed, 47 insertions(+), 41 deletions(-)

diff --git a/benchmark/kernels/fused_moe_triton/benchmark_sum_scale.py b/benchmark/kernels/fused_moe_triton/benchmark_sum_scale.py
index 13ff61744..979d2bbd1 100644
--- a/benchmark/kernels/fused_moe_triton/benchmark_sum_scale.py
+++ b/benchmark/kernels/fused_moe_triton/benchmark_sum_scale.py
@@ -4,7 +4,6 @@
 import triton.language as tl
 from triton.testing import do_bench
 
-# _moe_sum_reduce_kernel kernel modified from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/fused_moe/moe_sum_reduce.py
 @triton.jit
 def _moe_sum_reduce_kernel(
     input_ptr,
@@ -29,31 +28,35 @@ def _moe_sum_reduce_kernel(
     token_block_id = tl.program_id(0)
     dim_block_id = tl.program_id(1)
 
-    token_start = token_block_id * BLOCK_M
-    token_end = min((token_block_id + 1) * BLOCK_M, token_num)
+    offs_token = token_block_id * BLOCK_M + tl.arange(0, BLOCK_M)
+    offs_dim = dim_block_id * BLOCK_DIM + tl.arange(0, BLOCK_DIM)
 
-    dim_start = dim_block_id * BLOCK_DIM
-    dim_end = min((dim_block_id + 1) * BLOCK_DIM, hidden_dim)
+    mask_token = offs_token < token_num
+    mask_dim = offs_dim < hidden_dim
 
-    offs_dim = dim_start + tl.arange(0, BLOCK_DIM)
+    base_ptrs = input_ptr + offs_token[:, None] * input_stride_0 + offs_dim[None, :]
 
-    for token_index in range(token_start, token_end):
-        accumulator = tl.zeros((BLOCK_DIM,), dtype=tl.float32)
-        input_t_ptr = input_ptr + token_index * input_stride_0 + offs_dim
-        for i in tl.range(0, topk_num, num_stages=NUM_STAGE):
-            tmp = tl.load(
-                input_t_ptr + i * input_stride_1, mask=offs_dim < dim_end, other=0.0
-            )
-            accumulator += tmp
-        accumulator = accumulator * routed_scaling_factor
-        store_t_ptr = output_ptr + token_index * output_stride_0 + offs_dim
-        tl.store(
-            store_t_ptr,
-            accumulator.to(input_ptr.dtype.element_ty),
-            mask=offs_dim < dim_end,
+    accumulator = tl.zeros((BLOCK_M, BLOCK_DIM), dtype=tl.float32)
+
+    for i in tl.range(0, topk_num, num_stages=NUM_STAGE):
+        tile = tl.load(
+            base_ptrs + i * input_stride_1,
+            mask=mask_token[:, None] & mask_dim[None, :],
+            other=0.0,
         )
+        accumulator += tile.to(tl.float32)
+    accumulator *= routed_scaling_factor
+
+    # -------- Write back --------
+    store_ptrs = output_ptr + offs_token[:, None] * output_stride_0 + offs_dim[None, :]
+    tl.store(
+        store_ptrs,
+        accumulator.to(input_ptr.dtype.element_ty),
+        mask=mask_token[:, None] & mask_dim[None, :],
+    )
 
 
+# _moe_sum_reduce_kernel kernel modified from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/fused_moe/moe_sum_reduce.py
 def moe_sum_reduce(
     input: torch.Tensor, output: torch.Tensor, routed_scaling_factor: float
 ):
@@ -66,7 +69,7 @@ def moe_sum_reduce(
     BLOCK_M = 1
     BLOCK_DIM = 2048
     NUM_STAGE = 1
-    num_warps = 8
+    num_warps = 16
 
     grid = (
         triton.cdiv(token_num, BLOCK_M),
diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py
index 94f356e28..6a7229a9b 100644
--- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py
@@ -735,29 +735,32 @@ def _moe_sum_reduce_kernel(
     token_block_id = tl.program_id(0)
     dim_block_id = tl.program_id(1)
 
-    token_start = token_block_id * BLOCK_M
-    token_end = min((token_block_id + 1) * BLOCK_M, token_num)
+    offs_token = token_block_id * BLOCK_M + tl.arange(0, BLOCK_M)
+    offs_dim = dim_block_id * BLOCK_DIM + tl.arange(0, BLOCK_DIM)
 
-    dim_start = dim_block_id * BLOCK_DIM
-    dim_end = min((dim_block_id + 1) * BLOCK_DIM, hidden_dim)
+    mask_token = offs_token < token_num
+    mask_dim = offs_dim < hidden_dim
 
-    offs_dim = dim_start + tl.arange(0, BLOCK_DIM)
+    base_ptrs = input_ptr + offs_token[:, None] * input_stride_0 + offs_dim[None, :]
 
-    for token_index in range(token_start, token_end):
-        accumulator = tl.zeros((BLOCK_DIM,), dtype=tl.float32)
-        input_t_ptr = input_ptr + token_index * input_stride_0 + offs_dim
-        for i in tl.range(0, topk_num, num_stages=NUM_STAGE):
-            tmp = tl.load(
-                input_t_ptr + i * input_stride_1, mask=offs_dim < dim_end, other=0.0
-            )
-            accumulator += tmp
-        accumulator = accumulator * routed_scaling_factor
-        store_t_ptr = output_ptr + token_index * output_stride_0 + offs_dim
-        tl.store(
-            store_t_ptr,
-            accumulator.to(input_ptr.dtype.element_ty),
-            mask=offs_dim < dim_end,
+    accumulator = tl.zeros((BLOCK_M, BLOCK_DIM), dtype=tl.float32)
+
+    for i in tl.range(0, topk_num, num_stages=NUM_STAGE):
+        tile = tl.load(
+            base_ptrs + i * input_stride_1,
+            mask=mask_token[:, None] & mask_dim[None, :],
+            other=0.0,
         )
+        accumulator += tile.to(tl.float32)
+    accumulator *= routed_scaling_factor
+
+    # -------- Write back --------
+    store_ptrs = output_ptr + offs_token[:, None] * output_stride_0 + offs_dim[None, :]
+    tl.store(
+        store_ptrs,
+        accumulator.to(input_ptr.dtype.element_ty),
+        mask=mask_token[:, None] & mask_dim[None, :],
+    )
 
 
 def moe_sum_reduce_triton(
@@ -772,7 +775,7 @@ def moe_sum_reduce_triton(
     BLOCK_M = 1
     BLOCK_DIM = 2048
     NUM_STAGE = 1
-    num_warps = 8
+    num_warps = 16
 
     grid = (
         triton.cdiv(token_num, BLOCK_M),