[fix] illegal memory in _fwd_kernel_ep_scatter_2 and _fwd_kernel_ep_gather (#6348)
This commit is contained in:
@@ -791,19 +791,23 @@ def _fwd_kernel_ep_scatter_2(
|
|||||||
offset_in_s = tl.arange(0, SCALE_HIDDEN_SIZE_PAD)
|
offset_in_s = tl.arange(0, SCALE_HIDDEN_SIZE_PAD)
|
||||||
mask_s = offset_in_s < SCALE_HIDDEN_SIZE
|
mask_s = offset_in_s < SCALE_HIDDEN_SIZE
|
||||||
|
|
||||||
for token_id in range(start_token_id, total_token_num, grid_num):
|
for token_id_int32 in range(start_token_id, total_token_num, grid_num):
|
||||||
|
token_id = token_id_int32.to(tl.int64)
|
||||||
to_copy = tl.load(recv_x + token_id * recv_x_stride0 + offset_in, mask=mask)
|
to_copy = tl.load(recv_x + token_id * recv_x_stride0 + offset_in, mask=mask)
|
||||||
to_copy_s = tl.load(
|
to_copy_s = tl.load(
|
||||||
recv_x_scale + token_id * recv_x_scale_stride0 + offset_in_s, mask=mask_s
|
recv_x_scale + token_id * recv_x_scale_stride0 + offset_in_s, mask=mask_s
|
||||||
)
|
)
|
||||||
|
|
||||||
for topk_index in tl.range(0, topk_num, 1, num_stages=4):
|
for topk_idx_int32 in tl.range(0, topk_num, 1, num_stages=4):
|
||||||
|
topk_index = topk_idx_int32.to(tl.int64)
|
||||||
expert_id = tl.load(recv_topk + token_id * recv_topk_stride0 + topk_index)
|
expert_id = tl.load(recv_topk + token_id * recv_topk_stride0 + topk_index)
|
||||||
if expert_id >= 0:
|
if expert_id >= 0:
|
||||||
dest_token_index = tl.atomic_add(expert_start_loc + expert_id, 1)
|
dest_token_index_int32 = tl.atomic_add(expert_start_loc + expert_id, 1)
|
||||||
|
dest_token_index = dest_token_index_int32.to(tl.int64)
|
||||||
|
|
||||||
tl.store(
|
tl.store(
|
||||||
output_index + token_id * output_index_stride0 + topk_index,
|
output_index + token_id * output_index_stride0 + topk_index,
|
||||||
dest_token_index,
|
dest_token_index_int32,
|
||||||
)
|
)
|
||||||
output_tensor_ptr = (
|
output_tensor_ptr = (
|
||||||
output_tensor + dest_token_index * output_tensor_stride0
|
output_tensor + dest_token_index * output_tensor_stride0
|
||||||
@@ -902,21 +906,31 @@ def _fwd_kernel_ep_gather(
|
|||||||
topk_num: tl.constexpr,
|
topk_num: tl.constexpr,
|
||||||
BLOCK_D: tl.constexpr,
|
BLOCK_D: tl.constexpr,
|
||||||
):
|
):
|
||||||
cur_block = tl.program_id(0)
|
cur_block_int32 = tl.program_id(0)
|
||||||
start_cur_token = tl.program_id(1)
|
cur_block = cur_block_int32.to(tl.int64)
|
||||||
|
|
||||||
|
start_cur_token_int32 = tl.program_id(1)
|
||||||
|
|
||||||
grid_num = tl.num_programs(1)
|
grid_num = tl.num_programs(1)
|
||||||
|
|
||||||
for cur_token in range(start_cur_token, total_token_num, grid_num):
|
for cur_token_int32 in range(start_cur_token_int32, total_token_num, grid_num):
|
||||||
|
cur_token = cur_token_int32.to(tl.int64)
|
||||||
|
|
||||||
off_d = tl.arange(0, BLOCK_D)
|
off_d = tl.arange(0, BLOCK_D)
|
||||||
accumulator = tl.zeros([BLOCK_D], dtype=tl.float32)
|
accumulator = tl.zeros([BLOCK_D], dtype=tl.float32)
|
||||||
for topk_index in range(0, topk_num):
|
|
||||||
|
for topk_index_int32 in range(0, topk_num):
|
||||||
|
topk_index = topk_index_int32.to(tl.int64)
|
||||||
|
|
||||||
expert_id = tl.load(
|
expert_id = tl.load(
|
||||||
recv_topk_ids + cur_token * recv_topk_ids_stride0 + topk_index
|
recv_topk_ids + cur_token * recv_topk_ids_stride0 + topk_index
|
||||||
)
|
)
|
||||||
if expert_id >= 0:
|
if expert_id >= 0:
|
||||||
source_token_index = tl.load(
|
source_token_index_int32 = tl.load(
|
||||||
input_index + cur_token * input_index_stride0 + topk_index
|
input_index + cur_token * input_index_stride0 + topk_index
|
||||||
)
|
)
|
||||||
|
source_token_index = source_token_index_int32.to(tl.int64)
|
||||||
|
|
||||||
acc_weight = tl.load(
|
acc_weight = tl.load(
|
||||||
recv_topk_weight + cur_token * recv_topk_weight_stride0 + topk_index
|
recv_topk_weight + cur_token * recv_topk_weight_stride0 + topk_index
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user