Refine pre_reorder_triton_kernel slightly to improve performance (#6627)

Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
This commit is contained in:
Yuan Luo
2025-05-28 15:15:23 +08:00
committed by GitHub
parent f4a8987f69
commit c087ddd686
2 changed files with 109 additions and 4 deletions

View File

@@ -184,8 +184,10 @@ def pre_reorder_triton_kernel(
src_idx = tl.program_id(0)
src2dst_ptr = src2dst_ptr + src_idx * topk
topk_ids_ptr = topk_ids_ptr + src_idx * topk
src_ptr = input_ptr + src_idx * hidden_size
vec = tl.arange(0, BLOCK_SIZE)
for idx in range(topk):
expert_id = tl.load(topk_ids_ptr + idx)
if expert_id >= start_expert_id and expert_id <= end_expert_id:
@@ -197,7 +199,7 @@ def pre_reorder_triton_kernel(
dst_idx = tl.load(src2dst_ptr + idx)
dst_ptr = gateup_input_ptr + dst_idx * hidden_size
for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
offset = start_offset + tl.arange(0, BLOCK_SIZE)
offset = start_offset + vec
mask = offset < hidden_size
in_data = tl.load(src_ptr + offset, mask=mask).to(tl.float32)
out_data = (in_data * scale).to(OutDtype)
@@ -481,8 +483,11 @@ def post_reorder_triton_kernel(
computed = False
store_ptr = output_ptr + src_idx * hidden_size
vec = tl.arange(0, BLOCK_SIZE)
for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
offset = start_offset + tl.arange(0, BLOCK_SIZE)
offset = start_offset + vec
mask = offset < hidden_size
sum_vec = tl.zeros([BLOCK_SIZE], dtype=InDtype)
@@ -499,7 +504,7 @@ def post_reorder_triton_kernel(
if computed == False:
for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
offset = start_offset + tl.arange(0, BLOCK_SIZE)
offset = start_offset + vec
mask = offset < hidden_size
tl.store(
store_ptr + offset, tl.zeros([BLOCK_SIZE], dtype=InDtype), mask=mask