fix custom allreduce performance/accuracy problem (#4477)
This commit is contained in:
@@ -182,8 +182,9 @@ __inline__ __device__ void block_barrier(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if constexpr (start || need_fence) {
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, int RANKS_PER_NODE, bool COPY_INPUT = true>
|
template <typename T, int RANKS_PER_NODE, bool COPY_INPUT = true>
|
||||||
@@ -262,6 +263,8 @@ static __global__ void __launch_bounds__(512, 1) oneShotAllReduceKernel(AllReduc
|
|||||||
// Store to the destination buffer.
|
// Store to the destination buffer.
|
||||||
*reinterpret_cast<int4*>(&reinterpret_cast<T*>(params.local_output_buffer_ptr)[iter_offset]) = sums.packed;
|
*reinterpret_cast<int4*>(&reinterpret_cast<T*>(params.local_output_buffer_ptr)[iter_offset]) = sums.packed;
|
||||||
}
|
}
|
||||||
|
block_barrier<false>(
|
||||||
|
params.peer_barrier_ptrs_out, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx, grid_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, int RANKS_PER_NODE, bool COPY_INPUT = true>
|
template <typename T, int RANKS_PER_NODE, bool COPY_INPUT = true>
|
||||||
@@ -437,24 +440,8 @@ std::tuple<int, int> kernelLaunchConfig(AllReduceStrategyType algo, AllReducePar
|
|||||||
assert(params.elts_total % (elts_per_thread * params.ranks_per_node) == 0);
|
assert(params.elts_total % (elts_per_thread * params.ranks_per_node) == 0);
|
||||||
size_t const total_threads = roundUp(params.elts_total / (elts_per_thread * params.ranks_per_node), WARP_SIZE);
|
size_t const total_threads = roundUp(params.elts_total / (elts_per_thread * params.ranks_per_node), WARP_SIZE);
|
||||||
|
|
||||||
/*
|
|
||||||
threads_per_block = std::min(DEFAULT_BLOCK_SIZE, total_threads);
|
threads_per_block = std::min(DEFAULT_BLOCK_SIZE, total_threads);
|
||||||
blocks_per_grid = std::min(static_cast<size_t>(MAX_ALL_REDUCE_BLOCKS), divUp(total_threads, threads_per_block));
|
blocks_per_grid = std::min(static_cast<int>(MAX_ALL_REDUCE_BLOCKS), divUp(total_threads, threads_per_block));
|
||||||
*/
|
|
||||||
while (total_threads % blocks_per_grid != 0 || total_threads / blocks_per_grid > DEFAULT_BLOCK_SIZE) {
|
|
||||||
blocks_per_grid += 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
threads_per_block = total_threads / blocks_per_grid;
|
|
||||||
|
|
||||||
// NOTE: need to adjust here
|
|
||||||
if (blocks_per_grid > MAX_ALL_REDUCE_BLOCKS) {
|
|
||||||
size_t iter_factor = 1;
|
|
||||||
while (blocks_per_grid / iter_factor > MAX_ALL_REDUCE_BLOCKS || blocks_per_grid % iter_factor) {
|
|
||||||
iter_factor += 1;
|
|
||||||
}
|
|
||||||
blocks_per_grid /= iter_factor;
|
|
||||||
}
|
|
||||||
params.elts_per_rank = params.elts_total / params.ranks_per_node;
|
params.elts_per_rank = params.elts_total / params.ranks_per_node;
|
||||||
params.rank_offset = params.local_rank * params.elts_per_rank;
|
params.rank_offset = params.local_rank * params.elts_per_rank;
|
||||||
params.elts_per_block = roundUp(divUp(params.elts_per_rank, blocks_per_grid), elts_per_thread);
|
params.elts_per_block = roundUp(divUp(params.elts_per_rank, blocks_per_grid), elts_per_thread);
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace trt_llm {
|
namespace trt_llm {
|
||||||
constexpr size_t WARP_SIZE = 32;
|
constexpr size_t WARP_SIZE = 32;
|
||||||
constexpr size_t MAX_ALL_REDUCE_BLOCKS = 36;
|
constexpr size_t MAX_ALL_REDUCE_BLOCKS = 32;
|
||||||
constexpr size_t MAX_RANKS_PER_NODE = 8;
|
constexpr size_t MAX_RANKS_PER_NODE = 8;
|
||||||
constexpr size_t DEFAULT_BLOCK_SIZE = 512;
|
constexpr size_t DEFAULT_BLOCK_SIZE = 512;
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user