Refine pre_reorder_triton_kernel slightly to improve performance (#6627)
Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
This commit is contained in:
@@ -184,8 +184,10 @@ def pre_reorder_triton_kernel(
|
||||
src_idx = tl.program_id(0)
|
||||
src2dst_ptr = src2dst_ptr + src_idx * topk
|
||||
topk_ids_ptr = topk_ids_ptr + src_idx * topk
|
||||
|
||||
src_ptr = input_ptr + src_idx * hidden_size
|
||||
|
||||
+ vec = tl.arange(0, BLOCK_SIZE)
|
||||
|
||||
for idx in range(topk):
|
||||
expert_id = tl.load(topk_ids_ptr + idx)
|
||||
if expert_id >= start_expert_id and expert_id <= end_expert_id:
|
||||
@@ -197,7 +199,7 @@ def pre_reorder_triton_kernel(
|
||||
dst_idx = tl.load(src2dst_ptr + idx)
|
||||
dst_ptr = gateup_input_ptr + dst_idx * hidden_size
|
||||
for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
|
||||
- offset = start_offset + tl.arange(0, BLOCK_SIZE)
|
||||
+ offset = start_offset + vec
|
||||
mask = offset < hidden_size
|
||||
in_data = tl.load(src_ptr + offset, mask=mask).to(tl.float32)
|
||||
out_data = (in_data * scale).to(OutDtype)
|
||||
@@ -481,8 +483,11 @@ def post_reorder_triton_kernel(
|
||||
|
||||
computed = False
|
||||
store_ptr = output_ptr + src_idx * hidden_size
|
||||
|
||||
+ vec = tl.arange(0, BLOCK_SIZE)
|
||||
|
||||
for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
|
||||
- offset = start_offset + tl.arange(0, BLOCK_SIZE)
|
||||
+ offset = start_offset + vec
|
||||
mask = offset < hidden_size
|
||||
|
||||
sum_vec = tl.zeros([BLOCK_SIZE], dtype=InDtype)
|
||||
@@ -499,7 +504,7 @@ def post_reorder_triton_kernel(
|
||||
|
||||
if computed == False:
|
||||
for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
|
||||
- offset = start_offset + tl.arange(0, BLOCK_SIZE)
|
||||
+ offset = start_offset + vec
|
||||
mask = offset < hidden_size
|
||||
tl.store(
|
||||
store_ptr + offset, tl.zeros([BLOCK_SIZE], dtype=InDtype), mask=mask
|
||||
|
||||
Reference in New Issue
Block a user