[sgl-kernel] misc: update deepgemm version for sgl-kernel (#9340)

Co-authored-by: Yineng Zhang <me@zhyncs.com>
Co-authored-by: fzyzcjy <ch271828n@outlook.com>
This commit is contained in:
PGFLMG
2025-08-28 03:01:30 +08:00
committed by GitHub
parent 07ee0ab750
commit aa3eba8eb4
25 changed files with 210 additions and 383 deletions

View File

@@ -23,7 +23,6 @@ limitations under the License.
#ifndef USE_ROCM
#include <cub/cub.cuh>
#include <cub/util_type.cuh>
#include <cuda/functional>
#else
#include <hipcub/hipcub.hpp>
#include <hipcub/util_type.hpp>
@@ -34,16 +33,6 @@ limitations under the License.
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
// Define reduction operators based on CUDA version
// CUDA 13 (12.9+) deprecated cub::Max/Min in favor of cuda::maximum/minimum
#if CUDA_VERSION >= 12090
using MaxReduceOp = cuda::maximum<>;
using MinReduceOp = cuda::minimum<>;
#else
using MaxReduceOp = cub::Max;
using MinReduceOp = cub::Min;
#endif
/// Aligned array type
template <
typename T,
@@ -83,6 +72,7 @@ __launch_bounds__(TPB) __global__
const int thread_row_offset = blockIdx.x * num_cols;
cub::Sum sum;
float threadData(-FLT_MAX);
// Don't touch finished rows.
@@ -95,7 +85,7 @@ __launch_bounds__(TPB) __global__
threadData = max(convert_to_float<T>(input[idx]), threadData);
}
const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, MaxReduceOp());
const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
if (threadIdx.x == 0) {
float_max = maxElem;
@@ -109,7 +99,7 @@ __launch_bounds__(TPB) __global__
threadData += exp((convert_to_float<T>(input[idx]) - float_max));
}
const auto Z = BlockReduce(tmpStorage).Sum(threadData);
const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum);
if (threadIdx.x == 0) {
normalizing_factor = 1.f / Z;