enginex-ascend-910-llama.cpp/ggml/src/ggml-cuda/reduce_rows.cuh

CUDA: Optimize `reduce_rows_f32` kernel, leading up to 25x perf improvement on kernel-level and 10% perf increase for Gemma3n (#15132)

* Factor out `reduce_rows_f32` from common.cuh

  This increases iteration cycle speed by not having to recompile every kernel all the time

* Hide memory-latency by loop unrolling in reduce_rows_f32

* Further optimizations to `reduce_rows_f32`

  1. Increase threadblock size to better hide latency of memory requests. As a consequence of bigger threadblocks, do 2-step summation, using shared memory to communicate results between invocations
  2. Use sum_temp array to reduce waits on sum
  3. Adjust num_unroll to reflect bigger threadblock
  4. Improve default block_dims, increase support for more block_dims

* Add perf tests for `reduce_rows_f32` kernel

* Add heuristic to toggle 128/512 threads based on sm count

  Break even point was the minimum of the following multiples.

  | GPU Model | Nrow SM Count Multiple |
  | ----------- | ----------- |
  | RTX 4000 SFF ADA | 2.0x |
  | RTX 6000 ADA | 2.5x |
  | RTX PRO 6000 Blackwell Max-Q | 3.04x |
  | RTX PRO 4500 Blackwell | 3.15x |

* Ensure perf gains also for small ncols and large nrows

  Alternative to this, one could have also made the number of unrollings template-able, but that would require compiling the kernel multiple times, increasing binary size unnecessarily

* Modify perf and unit-tests

* Apply auto-formatting by clang

* Fix CI build failure

  See https://github.com/ggml-org/llama.cpp/actions/runs/16798370266/job/47573716079?pr=15132#step:7:486
  Building with VS generator worked though.

* Remove sm_count property from `ggml_backend_cuda_context`

  Requested by @JohannesGaessler, and should fix remaining CI issues as a side-effect

* Add CUB-based implementation for GGML_OP_MEAN

  Currently this branch is only executed for nrows==1

* Add heuristics to execute CUB branch only when it brings perf

  Heuristics were determined on the following HW:
  * RTX 4000 SFF ADA
  * RTX 6000 ADA
  * RTX PRO 6000 Blackwell Max-Q
  * RTX PRO 4500 Blackwell

* Add unit-test for CUB-based mean

  Tests should run with CUDA Graphs enabled per default on NVGPUs

* Rename `USE_CUB` to `GGML_CUDA_USE_CUB`

  Suggested by @JohannesGaessler

* Unindent Preprocessor directives

  See https://github.com/ggml-org/llama.cpp/pull/15132#discussion_r2269213506
2025-08-13 10:04:46 +02:00
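
The 128/512-thread heuristic described in the commit message above could look roughly like the following host-side sketch. The function name `pick_reduce_rows_block_size` and its parameters are hypothetical and do not come from this file. The 2x threshold is the smallest break-even multiple listed in the commit message; that the larger block is chosen when there are few rows per SM is an assumption.

```cpp
#include <cstdint>

// Hypothetical helper, not part of reduce_rows.cuh: choose the number of
// threads per block for a reduce_rows_f32 launch.
// Assumption: with fewer than 2 rows per SM the GPU is under-occupied, so a
// larger block (512 threads) hides memory latency better; otherwise 128.
static int pick_reduce_rows_block_size(int64_t nrows, int sm_count) {
    // 2.0x is the minimum break-even multiple reported in the commit message.
    const int64_t break_even_rows = 2 * static_cast<int64_t>(sm_count);
    return nrows < break_even_rows ? 512 : 128;
}
```

For a device with 48 SMs, this sketch would pick 512 threads for 10 rows and 128 threads for 96 or more rows.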
#include "common.cuh"

// Row reduction kernel template - compute sum (norm=false) or mean (norm=true)
template <bool norm>
static __global__ void reduce_rows_f32(const float * __restrict__ x, float * __restrict__ dst, const int ncols) {
    const int row = blockIdx.x;
    const int col = threadIdx.x;

    float sum = 0.0f;
    const int num_unroll = 8;
    float temp[num_unroll];
    float sum_temp[num_unroll] = { 0.0f };
    for (int i = col; i < ncols;) {
        for (int j = 0; j < num_unroll; ++j) {
            if (i < ncols) {
                temp[j] = x[row * ncols + i];
            } else {
                temp[j] = 0;
            }
            i += blockDim.x;
        }
        for (int j = 0; j < num_unroll; ++j) {
            sum_temp[j] += temp[j];
        }
    }
    for (int j = 0; j < num_unroll; ++j) {
        sum += sum_temp[j];
    }

    // sum up partial sums
    sum = warp_reduce_sum(sum);
    if (blockDim.x > WARP_SIZE) {
        assert((blockDim.x <= 1024) && (blockDim.x % WARP_SIZE) == 0);
        __shared__ float s_sum[32];
        const int warp_id = threadIdx.x / WARP_SIZE;
        const int lane_id = threadIdx.x % WARP_SIZE;
        if (lane_id == 0) {
            s_sum[warp_id] = sum;
        }
        __syncthreads();
        sum = 0.0f;
        if (lane_id < (static_cast<int>(blockDim.x) / WARP_SIZE)) {
            sum = s_sum[lane_id];
        }
        sum = warp_reduce_sum(sum);
    }

    if (col != 0) {
        return;
    }

    dst[row] = norm ? sum / ncols : sum;
}
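
The kernel's three phases (strided accumulation with unrolled partial sums, per-warp reduction, cross-warp reduction) can be mirrored on the host for one row, which may help when checking kernel output in a unit test. The following is a minimal sketch, not the project's test code: `reduce_row_emulated` is a hypothetical name, `block_size` stands in for `blockDim.x`, and a warp size of 32 is assumed. It uses sequential adds where the kernel uses `warp_reduce_sum`, so the order of floating-point additions differs and results may differ in the last bits for general inputs.

```cpp
#include <cassert>
#include <vector>

// Hypothetical host-side emulation of reduce_rows_f32 for a single row.
template <bool norm>
static float reduce_row_emulated(const float * x, int ncols, int block_size, int warp_size = 32) {
    assert(block_size > 0 && block_size <= 1024 && block_size % warp_size == 0);
    constexpr int num_unroll = 8;

    // Phase 1: each emulated thread strides over the row by block_size and
    // keeps num_unroll independent accumulators, as the kernel does.
    std::vector<float> thread_sum(block_size, 0.0f);
    for (int col = 0; col < block_size; ++col) {
        float sum_temp[num_unroll] = { 0.0f };
        for (int i = col; i < ncols;) {
            for (int j = 0; j < num_unroll; ++j) {
                sum_temp[j] += (i < ncols) ? x[i] : 0.0f;  // out-of-range loads contribute 0
                i += block_size;
            }
        }
        for (int j = 0; j < num_unroll; ++j) {
            thread_sum[col] += sum_temp[j];
        }
    }

    // Phase 2: one partial sum per warp (the kernel stores these in s_sum[]).
    const int n_warps = block_size / warp_size;
    std::vector<float> warp_sum(n_warps, 0.0f);
    for (int t = 0; t < block_size; ++t) {
        warp_sum[t / warp_size] += thread_sum[t];
    }

    // Phase 3: combine the per-warp partial sums into the row total.
    float sum = 0.0f;
    for (int w = 0; w < n_warps; ++w) {
        sum += warp_sum[w];
    }
    return norm ? sum / ncols : sum;
}
```

Calling `reduce_row_emulated<false>(row_ptr, ncols, 128)` yields the row sum and `reduce_row_emulated<true>(...)` the row mean. For comparison against GPU output, a tolerance-based check is safer than exact equality unless the inputs are small integers.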