Optimize moe_sum_reduce_kernel (#9477)
Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>
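Summary of the change: the body of _moe_sum_reduce_kernel is rewritten from a per-token scalar loop into a blocked form. Each program now owns a (BLOCK_M, BLOCK_DIM) tile: it builds token and hidden-dim offsets once, accumulates all topk_num expert slices into a float32 tile, applies routed_scaling_factor once, and writes the tile back with a single masked store. The launch configuration also moves from num_warps = 8 to num_warps = 16. The same rewrite is applied to both copies of the kernel touched by this diff.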
@@ -4,7 +4,6 @@ import triton.language as tl
 from triton.testing import do_bench
 
 
-# _moe_sum_reduce_kernel kernel modified from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/fused_moe/moe_sum_reduce.py
 @triton.jit
 def _moe_sum_reduce_kernel(
     input_ptr,
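The diff context cuts the parameter list off after input_ptr. For orientation while reading the hunks below, here is a reconstructed signature inferred purely from the names used in the kernel body; the exact order and any omitted parameters are assumptions, not part of the commit:

    # Hypothetical reconstruction -- inferred from usage, not shown in the diff.
    @triton.jit
    def _moe_sum_reduce_kernel(
        input_ptr,             # (token_num, topk_num, hidden_dim) tensor
        input_stride_0,        # stride between tokens
        input_stride_1,        # stride between top-k expert slices
        output_ptr,            # (token_num, hidden_dim) tensor
        output_stride_0,       # stride between output tokens
        token_num,
        topk_num,
        hidden_dim,
        routed_scaling_factor,
        BLOCK_M: tl.constexpr,
        BLOCK_DIM: tl.constexpr,
        NUM_STAGE: tl.constexpr,
    ):
        ...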
@@ -29,31 +28,35 @@ def _moe_sum_reduce_kernel(
     token_block_id = tl.program_id(0)
     dim_block_id = tl.program_id(1)
 
-    token_start = token_block_id * BLOCK_M
-    token_end = min((token_block_id + 1) * BLOCK_M, token_num)
+    offs_token = token_block_id * BLOCK_M + tl.arange(0, BLOCK_M)
+    offs_dim = dim_block_id * BLOCK_DIM + tl.arange(0, BLOCK_DIM)
 
-    dim_start = dim_block_id * BLOCK_DIM
-    dim_end = min((dim_block_id + 1) * BLOCK_DIM, hidden_dim)
+    mask_token = offs_token < token_num
+    mask_dim = offs_dim < hidden_dim
 
-    offs_dim = dim_start + tl.arange(0, BLOCK_DIM)
+    base_ptrs = input_ptr + offs_token[:, None] * input_stride_0 + offs_dim[None, :]
 
-    for token_index in range(token_start, token_end):
-        accumulator = tl.zeros((BLOCK_DIM,), dtype=tl.float32)
-        input_t_ptr = input_ptr + token_index * input_stride_0 + offs_dim
-        for i in tl.range(0, topk_num, num_stages=NUM_STAGE):
-            tmp = tl.load(
-                input_t_ptr + i * input_stride_1, mask=offs_dim < dim_end, other=0.0
-            )
-            accumulator += tmp
-        accumulator = accumulator * routed_scaling_factor
-        store_t_ptr = output_ptr + token_index * output_stride_0 + offs_dim
-        tl.store(
-            store_t_ptr,
-            accumulator.to(input_ptr.dtype.element_ty),
-            mask=offs_dim < dim_end,
-        )
+    accumulator = tl.zeros((BLOCK_M, BLOCK_DIM), dtype=tl.float32)
+    for i in tl.range(0, topk_num, num_stages=NUM_STAGE):
+        tile = tl.load(
+            base_ptrs + i * input_stride_1,
+            mask=mask_token[:, None] & mask_dim[None, :],
+            other=0.0,
+        )
+        accumulator += tile.to(tl.float32)
+
+    accumulator *= routed_scaling_factor
+
+    # -------- Write back --------
+    store_ptrs = output_ptr + offs_token[:, None] * output_stride_0 + offs_dim[None, :]
+    tl.store(
+        store_ptrs,
+        accumulator.to(input_ptr.dtype.element_ty),
+        mask=mask_token[:, None] & mask_dim[None, :],
+    )
 
 
+# _moe_sum_reduce_kernel kernel modified from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/fused_moe/moe_sum_reduce.py
 def moe_sum_reduce(
     input: torch.Tensor, output: torch.Tensor, routed_scaling_factor: float
 ):
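Both versions compute the same reduction: sum the top-k expert outputs for each token, then scale. A minimal PyTorch reference for the semantics, assuming the (token_num, topk_num, hidden_dim) input layout implied by the strides; this sketch is mine, not part of the commit:

    import torch

    def moe_sum_reduce_ref(
        input: torch.Tensor,   # (token_num, topk_num, hidden_dim)
        output: torch.Tensor,  # (token_num, hidden_dim)
        routed_scaling_factor: float,
    ) -> None:
        # Accumulate in float32, as the kernel does, then cast back to the
        # input dtype for the store.
        acc = input.float().sum(dim=1) * routed_scaling_factor
        output.copy_(acc.to(input.dtype))

The difference between old and new is purely how the work is tiled: the old code iterated tokens one at a time with a (BLOCK_DIM,) accumulator, while the new code processes a (BLOCK_M, BLOCK_DIM) tile per program, so each tl.load and tl.store touches a 2D block.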
@@ -66,7 +69,7 @@ def moe_sum_reduce(
     BLOCK_M = 1
     BLOCK_DIM = 2048
     NUM_STAGE = 1
-    num_warps = 8
+    num_warps = 16
 
     grid = (
         triton.cdiv(token_num, BLOCK_M),
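With BLOCK_M = 1, the first grid axis launches one program per token; the hunk is cut off before the second axis, but the use of dim_block_id in the kernel implies it tiles hidden_dim by BLOCK_DIM. A hedged sketch of the full launch (the second grid axis and the kernel call are inferred, not shown by the diff):

    BLOCK_M = 1
    BLOCK_DIM = 2048
    NUM_STAGE = 1
    num_warps = 16

    grid = (
        triton.cdiv(token_num, BLOCK_M),
        triton.cdiv(hidden_dim, BLOCK_DIM),  # assumed second axis
    )
    _moe_sum_reduce_kernel[grid](
        input,
        input.stride(0),
        input.stride(1),
        output,
        output.stride(0),
        token_num,
        topk_num,
        hidden_dim,
        routed_scaling_factor,
        BLOCK_M=BLOCK_M,
        BLOCK_DIM=BLOCK_DIM,
        NUM_STAGE=NUM_STAGE,
        num_warps=num_warps,
    )

The two hunks below apply the identical kernel rewrite and num_warps change to a second copy of _moe_sum_reduce_kernel (at lines ~735 and ~772 of its file), whose wrapper is moe_sum_reduce_triton.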
@@ -735,29 +735,32 @@ def _moe_sum_reduce_kernel(
     token_block_id = tl.program_id(0)
     dim_block_id = tl.program_id(1)
 
-    token_start = token_block_id * BLOCK_M
-    token_end = min((token_block_id + 1) * BLOCK_M, token_num)
+    offs_token = token_block_id * BLOCK_M + tl.arange(0, BLOCK_M)
+    offs_dim = dim_block_id * BLOCK_DIM + tl.arange(0, BLOCK_DIM)
 
-    dim_start = dim_block_id * BLOCK_DIM
-    dim_end = min((dim_block_id + 1) * BLOCK_DIM, hidden_dim)
+    mask_token = offs_token < token_num
+    mask_dim = offs_dim < hidden_dim
 
-    offs_dim = dim_start + tl.arange(0, BLOCK_DIM)
+    base_ptrs = input_ptr + offs_token[:, None] * input_stride_0 + offs_dim[None, :]
 
-    for token_index in range(token_start, token_end):
-        accumulator = tl.zeros((BLOCK_DIM,), dtype=tl.float32)
-        input_t_ptr = input_ptr + token_index * input_stride_0 + offs_dim
-        for i in tl.range(0, topk_num, num_stages=NUM_STAGE):
-            tmp = tl.load(
-                input_t_ptr + i * input_stride_1, mask=offs_dim < dim_end, other=0.0
-            )
-            accumulator += tmp
-        accumulator = accumulator * routed_scaling_factor
-        store_t_ptr = output_ptr + token_index * output_stride_0 + offs_dim
-        tl.store(
-            store_t_ptr,
-            accumulator.to(input_ptr.dtype.element_ty),
-            mask=offs_dim < dim_end,
-        )
+    accumulator = tl.zeros((BLOCK_M, BLOCK_DIM), dtype=tl.float32)
+    for i in tl.range(0, topk_num, num_stages=NUM_STAGE):
+        tile = tl.load(
+            base_ptrs + i * input_stride_1,
+            mask=mask_token[:, None] & mask_dim[None, :],
+            other=0.0,
+        )
+        accumulator += tile.to(tl.float32)
+
+    accumulator *= routed_scaling_factor
+
+    # -------- Write back --------
+    store_ptrs = output_ptr + offs_token[:, None] * output_stride_0 + offs_dim[None, :]
+    tl.store(
+        store_ptrs,
+        accumulator.to(input_ptr.dtype.element_ty),
+        mask=mask_token[:, None] & mask_dim[None, :],
+    )
 
 
 def moe_sum_reduce_triton(
@@ -772,7 +775,7 @@ def moe_sum_reduce_triton(
     BLOCK_M = 1
     BLOCK_DIM = 2048
     NUM_STAGE = 1
-    num_warps = 8
+    num_warps = 16
 
     grid = (
         triton.cdiv(token_num, BLOCK_M),
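The first file imports do_bench from triton.testing, so the natural way to check the effect of these changes is a micro-benchmark along these lines (shapes are illustrative; nothing below is taken from the commit):

    import torch
    from triton.testing import do_bench

    token_num, topk_num, hidden_dim = 4096, 9, 7168  # illustrative shapes
    x = torch.randn(token_num, topk_num, hidden_dim, dtype=torch.bfloat16, device="cuda")
    out = torch.empty(token_num, hidden_dim, dtype=torch.bfloat16, device="cuda")

    # do_bench returns the measured runtime in milliseconds.
    ms = do_bench(lambda: moe_sum_reduce(x, out, routed_scaling_factor=1.0))
    print(f"moe_sum_reduce: {ms:.3f} ms")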