[fix] illegal memory in _fwd_kernel_ep_scatter_2 and _fwd_kernel_ep_gather (#6348)
This commit is contained in:
@@ -791,19 +791,23 @@ def _fwd_kernel_ep_scatter_2(
|
||||
offset_in_s = tl.arange(0, SCALE_HIDDEN_SIZE_PAD)
|
||||
mask_s = offset_in_s < SCALE_HIDDEN_SIZE
|
||||
|
||||
for token_id in range(start_token_id, total_token_num, grid_num):
|
||||
for token_id_int32 in range(start_token_id, total_token_num, grid_num):
|
||||
token_id = token_id_int32.to(tl.int64)
|
||||
to_copy = tl.load(recv_x + token_id * recv_x_stride0 + offset_in, mask=mask)
|
||||
to_copy_s = tl.load(
|
||||
recv_x_scale + token_id * recv_x_scale_stride0 + offset_in_s, mask=mask_s
|
||||
)
|
||||
|
||||
for topk_index in tl.range(0, topk_num, 1, num_stages=4):
|
||||
for topk_idx_int32 in tl.range(0, topk_num, 1, num_stages=4):
|
||||
topk_index = topk_idx_int32.to(tl.int64)
|
||||
expert_id = tl.load(recv_topk + token_id * recv_topk_stride0 + topk_index)
|
||||
if expert_id >= 0:
|
||||
dest_token_index = tl.atomic_add(expert_start_loc + expert_id, 1)
|
||||
dest_token_index_int32 = tl.atomic_add(expert_start_loc + expert_id, 1)
|
||||
dest_token_index = dest_token_index_int32.to(tl.int64)
|
||||
|
||||
tl.store(
|
||||
output_index + token_id * output_index_stride0 + topk_index,
|
||||
dest_token_index,
|
||||
dest_token_index_int32,
|
||||
)
|
||||
output_tensor_ptr = (
|
||||
output_tensor + dest_token_index * output_tensor_stride0
|
||||
@@ -902,21 +906,31 @@ def _fwd_kernel_ep_gather(
|
||||
topk_num: tl.constexpr,
|
||||
BLOCK_D: tl.constexpr,
|
||||
):
|
||||
cur_block = tl.program_id(0)
|
||||
start_cur_token = tl.program_id(1)
|
||||
cur_block_int32 = tl.program_id(0)
|
||||
cur_block = cur_block_int32.to(tl.int64)
|
||||
|
||||
start_cur_token_int32 = tl.program_id(1)
|
||||
|
||||
grid_num = tl.num_programs(1)
|
||||
|
||||
for cur_token in range(start_cur_token, total_token_num, grid_num):
|
||||
for cur_token_int32 in range(start_cur_token_int32, total_token_num, grid_num):
|
||||
cur_token = cur_token_int32.to(tl.int64)
|
||||
|
||||
off_d = tl.arange(0, BLOCK_D)
|
||||
accumulator = tl.zeros([BLOCK_D], dtype=tl.float32)
|
||||
for topk_index in range(0, topk_num):
|
||||
|
||||
for topk_index_int32 in range(0, topk_num):
|
||||
topk_index = topk_index_int32.to(tl.int64)
|
||||
|
||||
expert_id = tl.load(
|
||||
recv_topk_ids + cur_token * recv_topk_ids_stride0 + topk_index
|
||||
)
|
||||
if expert_id >= 0:
|
||||
source_token_index = tl.load(
|
||||
source_token_index_int32 = tl.load(
|
||||
input_index + cur_token * input_index_stride0 + topk_index
|
||||
)
|
||||
source_token_index = source_token_index_int32.to(tl.int64)
|
||||
|
||||
acc_weight = tl.load(
|
||||
recv_topk_weight + cur_token * recv_topk_weight_stride0 + topk_index
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user