support cuda 13.0 and trtllm kernel by Aug 25 2025 (#9495)

This commit is contained in:
Rain Jiang
2025-08-26 23:13:27 -07:00
committed by GitHub
parent 8f7b1c31e8
commit 79e6a8a6ac
13 changed files with 81 additions and 14 deletions

View File

@@ -23,6 +23,7 @@ limitations under the License.
#ifndef USE_ROCM
#include <cub/cub.cuh>
#include <cub/util_type.cuh>
#include <cuda/functional>
#else
#include <hipcub/hipcub.hpp>
#include <hipcub/util_type.hpp>
@@ -33,6 +34,16 @@ limitations under the License.
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
// Define reduction operators based on CUDA version
// CUDA 13 (12.9+) deprecated cub::Max/Min in favor of cuda::maximum/minimum
#if CUDA_VERSION >= 12090
using MaxReduceOp = cuda::maximum<>;
using MinReduceOp = cuda::minimum<>;
#else
using MaxReduceOp = cub::Max;
using MinReduceOp = cub::Min;
#endif
/// Aligned array type
template <
typename T,
@@ -72,7 +83,6 @@ __launch_bounds__(TPB) __global__
const int thread_row_offset = blockIdx.x * num_cols;
cub::Sum sum;
float threadData(-FLT_MAX);
// Don't touch finished rows.
@@ -85,7 +95,7 @@ __launch_bounds__(TPB) __global__
threadData = max(convert_to_float<T>(input[idx]), threadData);
}
const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, MaxReduceOp());
if (threadIdx.x == 0) {
float_max = maxElem;
@@ -99,7 +109,7 @@ __launch_bounds__(TPB) __global__
threadData += exp((convert_to_float<T>(input[idx]) - float_max));
}
const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum);
const auto Z = BlockReduce(tmpStorage).Sum(threadData);
if (threadIdx.x == 0) {
normalizing_factor = 1.f / Z;