From 5f6756b038ff5de318adbe2d8272ba1e8dc980c5 Mon Sep 17 00:00:00 2001 From: Morpheus Guo Date: Sun, 13 Jul 2025 04:42:36 +0800 Subject: [PATCH] [BugFix] fix pre_reorder_triton_kernel default int32 issue (#7814) --- python/sglang/srt/layers/moe/ep_moe/kernels.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/kernels.py b/python/sglang/srt/layers/moe/ep_moe/kernels.py index 7f9bdc748..d3ec90a7c 100644 --- a/python/sglang/srt/layers/moe/ep_moe/kernels.py +++ b/python/sglang/srt/layers/moe/ep_moe/kernels.py @@ -236,7 +236,8 @@ def pre_reorder_triton_kernel( ): OutDtype = gateup_input_ptr.dtype.element_ty - src_idx = tl.program_id(0) + src_idx_int32 = tl.program_id(0) + src_idx = src_idx_int32.to(tl.int64) src2dst_ptr = src2dst_ptr + src_idx * topk topk_ids_ptr = topk_ids_ptr + src_idx * topk src_ptr = input_ptr + src_idx * hidden_size @@ -255,7 +256,8 @@ def pre_reorder_triton_kernel( else: scale = 1.0 - dst_idx = tl.load(src2dst_ptr + idx) + dst_idx_int32 = tl.load(src2dst_ptr + idx) + dst_idx = dst_idx_int32.to(tl.int64) dst_ptr = gateup_input_ptr + dst_idx * hidden_size for start_offset in tl.range(0, hidden_size, BLOCK_SIZE): offset = start_offset + vec