Optimize moe_sum_reduce_kernel (#9477)

Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>
This commit is contained in:
Yuan Luo
2025-09-07 09:16:18 +08:00
committed by GitHub
parent f3b6760213
commit cb3918a091
2 changed files with 47 additions and 41 deletions

View File

@@ -735,29 +735,32 @@ def _moe_sum_reduce_kernel(
token_block_id = tl.program_id(0)
dim_block_id = tl.program_id(1)
token_start = token_block_id * BLOCK_M
token_end = min((token_block_id + 1) * BLOCK_M, token_num)
offs_token = token_block_id * BLOCK_M + tl.arange(0, BLOCK_M)
offs_dim = dim_block_id * BLOCK_DIM + tl.arange(0, BLOCK_DIM)
dim_start = dim_block_id * BLOCK_DIM
dim_end = min((dim_block_id + 1) * BLOCK_DIM, hidden_dim)
mask_token = offs_token < token_num
mask_dim = offs_dim < hidden_dim
offs_dim = dim_start + tl.arange(0, BLOCK_DIM)
base_ptrs = input_ptr + offs_token[:, None] * input_stride_0 + offs_dim[None, :]
for token_index in range(token_start, token_end):
accumulator = tl.zeros((BLOCK_DIM,), dtype=tl.float32)
input_t_ptr = input_ptr + token_index * input_stride_0 + offs_dim
for i in tl.range(0, topk_num, num_stages=NUM_STAGE):
tmp = tl.load(
input_t_ptr + i * input_stride_1, mask=offs_dim < dim_end, other=0.0
)
accumulator += tmp
accumulator = accumulator * routed_scaling_factor
store_t_ptr = output_ptr + token_index * output_stride_0 + offs_dim
tl.store(
store_t_ptr,
accumulator.to(input_ptr.dtype.element_ty),
mask=offs_dim < dim_end,
accumulator = tl.zeros((BLOCK_M, BLOCK_DIM), dtype=tl.float32)
for i in tl.range(0, topk_num, num_stages=NUM_STAGE):
tile = tl.load(
base_ptrs + i * input_stride_1,
mask=mask_token[:, None] & mask_dim[None, :],
other=0.0,
)
accumulator += tile.to(tl.float32)
accumulator *= routed_scaling_factor
# -------- Write back --------
store_ptrs = output_ptr + offs_token[:, None] * output_stride_0 + offs_dim[None, :]
tl.store(
store_ptrs,
accumulator.to(input_ptr.dtype.element_ty),
mask=mask_token[:, None] & mask_dim[None, :],
)
def moe_sum_reduce_triton(
@@ -772,7 +775,7 @@ def moe_sum_reduce_triton(
BLOCK_M = 1
BLOCK_DIM = 2048
NUM_STAGE = 1
num_warps = 8
num_warps = 16
grid = (
triton.cdiv(token_num, BLOCK_M),