From 25e1816eff104da56f97ce494e255306603fe2f6 Mon Sep 17 00:00:00 2001 From: Yi Zhang <1109276519@qq.com> Date: Mon, 17 Mar 2025 03:16:30 +0800 Subject: [PATCH] fix custom allreduce performance/accuracy problem (#4477) --- .../csrc/allreduce/trt_reduce_internal.cu | 25 +++++-------------- sgl-kernel/include/trt_reduce_internal.cuh | 2 +- 2 files changed, 7 insertions(+), 20 deletions(-) diff --git a/sgl-kernel/csrc/allreduce/trt_reduce_internal.cu b/sgl-kernel/csrc/allreduce/trt_reduce_internal.cu index f1ee5d40e..283e1e8ad 100644 --- a/sgl-kernel/csrc/allreduce/trt_reduce_internal.cu +++ b/sgl-kernel/csrc/allreduce/trt_reduce_internal.cu @@ -182,8 +182,9 @@ __inline__ __device__ void block_barrier( } } } - - __syncthreads(); + if constexpr (start || need_fence) { + __syncthreads(); + } } template @@ -262,6 +263,8 @@ static __global__ void __launch_bounds__(512, 1) oneShotAllReduceKernel(AllReduc // Store to the destination buffer. *reinterpret_cast(&reinterpret_cast(params.local_output_buffer_ptr)[iter_offset]) = sums.packed; } + block_barrier( + params.peer_barrier_ptrs_out, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx, grid_size); } template @@ -437,24 +440,8 @@ std::tuple kernelLaunchConfig(AllReduceStrategyType algo, AllReducePar assert(params.elts_total % (elts_per_thread * params.ranks_per_node) == 0); size_t const total_threads = roundUp(params.elts_total / (elts_per_thread * params.ranks_per_node), WARP_SIZE); - /* threads_per_block = std::min(DEFAULT_BLOCK_SIZE, total_threads); - blocks_per_grid = std::min(static_cast(MAX_ALL_REDUCE_BLOCKS), divUp(total_threads, threads_per_block)); - */ - while (total_threads % blocks_per_grid != 0 || total_threads / blocks_per_grid > DEFAULT_BLOCK_SIZE) { - blocks_per_grid += 1; - } - - threads_per_block = total_threads / blocks_per_grid; - - // NOTE: need to adjust here - if (blocks_per_grid > MAX_ALL_REDUCE_BLOCKS) { - size_t iter_factor = 1; - while (blocks_per_grid / iter_factor > MAX_ALL_REDUCE_BLOCKS || blocks_per_grid % iter_factor) { - iter_factor += 1; - } - blocks_per_grid /= iter_factor; - } + blocks_per_grid = std::min(static_cast(MAX_ALL_REDUCE_BLOCKS), divUp(total_threads, threads_per_block)); params.elts_per_rank = params.elts_total / params.ranks_per_node; params.rank_offset = params.local_rank * params.elts_per_rank; params.elts_per_block = roundUp(divUp(params.elts_per_rank, blocks_per_grid), elts_per_thread); diff --git a/sgl-kernel/include/trt_reduce_internal.cuh b/sgl-kernel/include/trt_reduce_internal.cuh index c670c994d..9fec59b65 100644 --- a/sgl-kernel/include/trt_reduce_internal.cuh +++ b/sgl-kernel/include/trt_reduce_internal.cuh @@ -39,7 +39,7 @@ limitations under the License. namespace trt_llm { constexpr size_t WARP_SIZE = 32; -constexpr size_t MAX_ALL_REDUCE_BLOCKS = 36; +constexpr size_t MAX_ALL_REDUCE_BLOCKS = 32; constexpr size_t MAX_RANKS_PER_NODE = 8; constexpr size_t DEFAULT_BLOCK_SIZE = 512;