diff --git a/sgl-kernel/csrc/gemm/per_token_group_quant_8bit.cu b/sgl-kernel/csrc/gemm/per_token_group_quant_8bit.cu index 1944e6d37..4c1d96a6a 100644 --- a/sgl-kernel/csrc/gemm/per_token_group_quant_8bit.cu +++ b/sgl-kernel/csrc/gemm/per_token_group_quant_8bit.cu @@ -8,7 +8,7 @@ template __device__ __forceinline__ float GroupReduceMax(float val, const int tid) { - unsigned mask = 0xffff; + unsigned mask = threadIdx.x % 32 >= 16 ? 0xffff0000 : 0x0000ffff; static_assert( (THREADS_PER_SUBWARP & (THREADS_PER_SUBWARP - 1)) == 0 && THREADS_PER_SUBWARP <= 16 && THREADS_PER_SUBWARP >= 1,