[Fix] illegal sync based on undefined behaviour (#9620)
Signed-off-by: Devashish Lal <devashish@rivosinc.com> Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>
This commit is contained in:
@@ -8,7 +8,7 @@
|
||||
|
||||
template <int THREADS_PER_SUBWARP>
|
||||
__device__ __forceinline__ float GroupReduceMax(float val, const int tid) {
|
||||
unsigned mask = 0xffff;
|
||||
unsigned mask = threadIdx.x % 32 >= 16 ? 0xffff0000 : 0x0000ffff;
|
||||
|
||||
static_assert(
|
||||
(THREADS_PER_SUBWARP & (THREADS_PER_SUBWARP - 1)) == 0 && THREADS_PER_SUBWARP <= 16 && THREADS_PER_SUBWARP >= 1,
|
||||
|
||||
Reference in New Issue
Block a user