From 439f65809f7c917165cbb962d7a6bb5167ecdcf9 Mon Sep 17 00:00:00 2001
From: Ke Bao
Date: Mon, 6 Jan 2025 21:59:31 +0800
Subject: [PATCH] Fix sgl-kernel cu118 compile issue (#2750)

---
 sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu
index b4d17ded1..a6f2d5216 100644
--- a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu
+++ b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu
@@ -302,8 +302,10 @@ static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduc
     buffers[ii] = reinterpret_cast(params.peer_comm_buffer_ptrs[rank]);
   }
 
+#if (defined(__CUDACC_VER_MAJOR__) && (__CUDACC_VER_MAJOR__ >= 12))
 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
   cudaGridDependencySynchronize();
+#endif
 #endif
 
   block_barrier(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx,
@@ -350,10 +352,11 @@ static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduc
       *reinterpret_cast(&local_output_buffer[offset_rank]) = *reinterpret_cast(&buffers[ii][offset_rank]);
     }
   }
-
+#if (defined(__CUDACC_VER_MAJOR__) && (__CUDACC_VER_MAJOR__ >= 12))
 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
   cudaTriggerProgrammaticLaunchCompletion();
 #endif
+#endif
 }
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////