From f076328bb7a6da81c75e579433676e27a02f4220 Mon Sep 17 00:00:00 2001 From: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com> Date: Thu, 13 Feb 2025 16:47:00 +0800 Subject: [PATCH] fix moe_align_kernel shm init not sync bug (#3534) --- sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu index b5326409e..6346efbd3 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu @@ -53,6 +53,8 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, int } } + __syncthreads(); + const size_t tokens_per_thread = CEILDIV(numel, blockDim.x); const size_t start_idx = threadIdx.x * tokens_per_thread;