From e41549c3d6d8ab4dc888dcb2de4a220e2a16567f Mon Sep 17 00:00:00 2001 From: saltyfish66 <38240284+saltyfish66@users.noreply.github.com> Date: Thu, 3 Apr 2025 15:07:32 +0800 Subject: [PATCH] fix: fix illegal cuda memory access at fused_moe_kernel (#4727) Co-authored-by: yuethe --- python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py index fbf676b8c..946d194a1 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py @@ -152,6 +152,7 @@ def fused_moe_kernel( return offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) + offs_token = offs_token.to(tl.int64) token_mask = offs_token < num_valid_tokens offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N