fix custom allreduce performance/accuracy problem (#4477)
This commit is contained in:
@@ -182,8 +182,9 @@ __inline__ __device__ void block_barrier(
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
if constexpr (start || need_fence) {
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int RANKS_PER_NODE, bool COPY_INPUT = true>
|
||||
@@ -262,6 +263,8 @@ static __global__ void __launch_bounds__(512, 1) oneShotAllReduceKernel(AllReduc
|
||||
// Store to the destination buffer.
|
||||
*reinterpret_cast<int4*>(&reinterpret_cast<T*>(params.local_output_buffer_ptr)[iter_offset]) = sums.packed;
|
||||
}
|
||||
block_barrier<false>(
|
||||
params.peer_barrier_ptrs_out, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx, grid_size);
|
||||
}
|
||||
|
||||
template <typename T, int RANKS_PER_NODE, bool COPY_INPUT = true>
|
||||
@@ -437,24 +440,8 @@ std::tuple<int, int> kernelLaunchConfig(AllReduceStrategyType algo, AllReducePar
|
||||
assert(params.elts_total % (elts_per_thread * params.ranks_per_node) == 0);
|
||||
size_t const total_threads = roundUp(params.elts_total / (elts_per_thread * params.ranks_per_node), WARP_SIZE);
|
||||
|
||||
/*
|
||||
threads_per_block = std::min(DEFAULT_BLOCK_SIZE, total_threads);
|
||||
blocks_per_grid = std::min(static_cast<size_t>(MAX_ALL_REDUCE_BLOCKS), divUp(total_threads, threads_per_block));
|
||||
*/
|
||||
while (total_threads % blocks_per_grid != 0 || total_threads / blocks_per_grid > DEFAULT_BLOCK_SIZE) {
|
||||
blocks_per_grid += 1;
|
||||
}
|
||||
|
||||
threads_per_block = total_threads / blocks_per_grid;
|
||||
|
||||
// NOTE: need to adjust here
|
||||
if (blocks_per_grid > MAX_ALL_REDUCE_BLOCKS) {
|
||||
size_t iter_factor = 1;
|
||||
while (blocks_per_grid / iter_factor > MAX_ALL_REDUCE_BLOCKS || blocks_per_grid % iter_factor) {
|
||||
iter_factor += 1;
|
||||
}
|
||||
blocks_per_grid /= iter_factor;
|
||||
}
|
||||
blocks_per_grid = std::min(static_cast<int>(MAX_ALL_REDUCE_BLOCKS), divUp(total_threads, threads_per_block));
|
||||
params.elts_per_rank = params.elts_total / params.ranks_per_node;
|
||||
params.rank_offset = params.local_rank * params.elts_per_rank;
|
||||
params.elts_per_block = roundUp(divUp(params.elts_per_rank, blocks_per_grid), elts_per_thread);
|
||||
|
||||
@@ -39,7 +39,7 @@ limitations under the License.
|
||||
|
||||
namespace trt_llm {
|
||||
constexpr size_t WARP_SIZE = 32;
|
||||
constexpr size_t MAX_ALL_REDUCE_BLOCKS = 36;
|
||||
constexpr size_t MAX_ALL_REDUCE_BLOCKS = 32;
|
||||
constexpr size_t MAX_RANKS_PER_NODE = 8;
|
||||
constexpr size_t DEFAULT_BLOCK_SIZE = 512;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user