From e3b8a72291af1fc807ef99553b2b7fac20c6edf8 Mon Sep 17 00:00:00 2001 From: xutizhou Date: Sun, 18 May 2025 08:01:42 +0800 Subject: [PATCH] [fix] illegal memory in _fwd_kernel_ep_scatter_2 and _fwd_kernel_ep_gather (#6348) --- .../sglang/srt/layers/moe/ep_moe/kernels.py | 32 +++++++++++++------ 1 file changed, 23 insertions(+), 9 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/kernels.py b/python/sglang/srt/layers/moe/ep_moe/kernels.py index 8ee41c06e..8c005527a 100644 --- a/python/sglang/srt/layers/moe/ep_moe/kernels.py +++ b/python/sglang/srt/layers/moe/ep_moe/kernels.py @@ -791,19 +791,23 @@ def _fwd_kernel_ep_scatter_2( offset_in_s = tl.arange(0, SCALE_HIDDEN_SIZE_PAD) mask_s = offset_in_s < SCALE_HIDDEN_SIZE - for token_id in range(start_token_id, total_token_num, grid_num): + for token_id_int32 in range(start_token_id, total_token_num, grid_num): + token_id = token_id_int32.to(tl.int64) to_copy = tl.load(recv_x + token_id * recv_x_stride0 + offset_in, mask=mask) to_copy_s = tl.load( recv_x_scale + token_id * recv_x_scale_stride0 + offset_in_s, mask=mask_s ) - for topk_index in tl.range(0, topk_num, 1, num_stages=4): + for topk_idx_int32 in tl.range(0, topk_num, 1, num_stages=4): + topk_index = topk_idx_int32.to(tl.int64) expert_id = tl.load(recv_topk + token_id * recv_topk_stride0 + topk_index) if expert_id >= 0: - dest_token_index = tl.atomic_add(expert_start_loc + expert_id, 1) + dest_token_index_int32 = tl.atomic_add(expert_start_loc + expert_id, 1) + dest_token_index = dest_token_index_int32.to(tl.int64) + tl.store( output_index + token_id * output_index_stride0 + topk_index, - dest_token_index, + dest_token_index_int32, ) output_tensor_ptr = ( output_tensor + dest_token_index * output_tensor_stride0 @@ -902,21 +906,31 @@ def _fwd_kernel_ep_gather( topk_num: tl.constexpr, BLOCK_D: tl.constexpr, ): - cur_block = tl.program_id(0) - start_cur_token = tl.program_id(1) + cur_block_int32 = tl.program_id(0) + cur_block = cur_block_int32.to(tl.int64) + + start_cur_token_int32 = tl.program_id(1) + grid_num = tl.num_programs(1) - for cur_token in range(start_cur_token, total_token_num, grid_num): + for cur_token_int32 in range(start_cur_token_int32, total_token_num, grid_num): + cur_token = cur_token_int32.to(tl.int64) + off_d = tl.arange(0, BLOCK_D) accumulator = tl.zeros([BLOCK_D], dtype=tl.float32) - for topk_index in range(0, topk_num): + + for topk_index_int32 in range(0, topk_num): + topk_index = topk_index_int32.to(tl.int64) + expert_id = tl.load( recv_topk_ids + cur_token * recv_topk_ids_stride0 + topk_index ) if expert_id >= 0: - source_token_index = tl.load( + source_token_index_int32 = tl.load( input_index + cur_token * input_index_stride0 + topk_index ) + source_token_index = source_token_index_int32.to(tl.int64) + acc_weight = tl.load( recv_topk_weight + cur_token * recv_topk_weight_stride0 + topk_index )