[sgl-kernel] misc: update deepgemm version for sgl-kernel (#9340)
Co-authored-by: Yineng Zhang <me@zhyncs.com>
Co-authored-by: fzyzcjy <ch271828n@outlook.com>
This commit is contained in:
@@ -23,7 +23,6 @@ limitations under the License.
|
||||
#ifndef USE_ROCM
|
||||
#include <cub/cub.cuh>
|
||||
#include <cub/util_type.cuh>
|
||||
#include <cuda/functional>
|
||||
#else
|
||||
#include <hipcub/hipcub.hpp>
|
||||
#include <hipcub/util_type.hpp>
|
||||
@@ -34,16 +33,6 @@ limitations under the License.
|
||||
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
||||
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
||||
|
||||
// Define reduction operators based on CUDA version
|
||||
// CUDA 12.9 and later (including CUDA 13) deprecate cub::Max/Min in favor of cuda::maximum/minimum
|
||||
#if CUDA_VERSION >= 12090
|
||||
using MaxReduceOp = cuda::maximum<>;
|
||||
using MinReduceOp = cuda::minimum<>;
|
||||
#else
|
||||
using MaxReduceOp = cub::Max;
|
||||
using MinReduceOp = cub::Min;
|
||||
#endif
|
||||
|
||||
/// Aligned array type
|
||||
template <
|
||||
typename T,
|
||||
@@ -83,6 +72,7 @@ __launch_bounds__(TPB) __global__
|
||||
|
||||
const int thread_row_offset = blockIdx.x * num_cols;
|
||||
|
||||
cub::Sum sum;
|
||||
float threadData(-FLT_MAX);
|
||||
|
||||
// Don't touch finished rows.
|
||||
@@ -95,7 +85,7 @@ __launch_bounds__(TPB) __global__
|
||||
threadData = max(convert_to_float<T>(input[idx]), threadData);
|
||||
}
|
||||
|
||||
const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, MaxReduceOp());
|
||||
const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
float_max = maxElem;
|
||||
@@ -109,7 +99,7 @@ __launch_bounds__(TPB) __global__
|
||||
threadData += exp((convert_to_float<T>(input[idx]) - float_max));
|
||||
}
|
||||
|
||||
const auto Z = BlockReduce(tmpStorage).Sum(threadData);
|
||||
const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum);
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
normalizing_factor = 1.f / Z;
|
||||
|
||||
Reference in New Issue
Block a user