Optimize moe_sum_reduce_kernel (#9477)
Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>
This commit is contained in:
@@ -735,29 +735,32 @@ def _moe_sum_reduce_kernel(
|
||||
token_block_id = tl.program_id(0)
|
||||
dim_block_id = tl.program_id(1)
|
||||
|
||||
token_start = token_block_id * BLOCK_M
|
||||
token_end = min((token_block_id + 1) * BLOCK_M, token_num)
|
||||
offs_token = token_block_id * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
offs_dim = dim_block_id * BLOCK_DIM + tl.arange(0, BLOCK_DIM)
|
||||
|
||||
dim_start = dim_block_id * BLOCK_DIM
|
||||
dim_end = min((dim_block_id + 1) * BLOCK_DIM, hidden_dim)
|
||||
mask_token = offs_token < token_num
|
||||
mask_dim = offs_dim < hidden_dim
|
||||
|
||||
offs_dim = dim_start + tl.arange(0, BLOCK_DIM)
|
||||
base_ptrs = input_ptr + offs_token[:, None] * input_stride_0 + offs_dim[None, :]
|
||||
|
||||
for token_index in range(token_start, token_end):
|
||||
accumulator = tl.zeros((BLOCK_DIM,), dtype=tl.float32)
|
||||
input_t_ptr = input_ptr + token_index * input_stride_0 + offs_dim
|
||||
for i in tl.range(0, topk_num, num_stages=NUM_STAGE):
|
||||
tmp = tl.load(
|
||||
input_t_ptr + i * input_stride_1, mask=offs_dim < dim_end, other=0.0
|
||||
)
|
||||
accumulator += tmp
|
||||
accumulator = accumulator * routed_scaling_factor
|
||||
store_t_ptr = output_ptr + token_index * output_stride_0 + offs_dim
|
||||
tl.store(
|
||||
store_t_ptr,
|
||||
accumulator.to(input_ptr.dtype.element_ty),
|
||||
mask=offs_dim < dim_end,
|
||||
accumulator = tl.zeros((BLOCK_M, BLOCK_DIM), dtype=tl.float32)
|
||||
|
||||
for i in tl.range(0, topk_num, num_stages=NUM_STAGE):
|
||||
tile = tl.load(
|
||||
base_ptrs + i * input_stride_1,
|
||||
mask=mask_token[:, None] & mask_dim[None, :],
|
||||
other=0.0,
|
||||
)
|
||||
accumulator += tile.to(tl.float32)
|
||||
accumulator *= routed_scaling_factor
|
||||
|
||||
# -------- Write back --------
|
||||
store_ptrs = output_ptr + offs_token[:, None] * output_stride_0 + offs_dim[None, :]
|
||||
tl.store(
|
||||
store_ptrs,
|
||||
accumulator.to(input_ptr.dtype.element_ty),
|
||||
mask=mask_token[:, None] & mask_dim[None, :],
|
||||
)
|
||||
|
||||
|
||||
def moe_sum_reduce_triton(
|
||||
@@ -772,7 +775,7 @@ def moe_sum_reduce_triton(
|
||||
BLOCK_M = 1
|
||||
BLOCK_DIM = 2048
|
||||
NUM_STAGE = 1
|
||||
num_warps = 8
|
||||
num_warps = 16
|
||||
|
||||
grid = (
|
||||
triton.cdiv(token_num, BLOCK_M),
|
||||
|
||||
Reference in New Issue
Block a user