sgl-kernel: use latest cutlass version for fp8 blockwise gemm (#5207)

This commit is contained in:
Yi Zhang
2025-04-10 02:47:04 +08:00
committed by GitHub
parent 7f875f1293
commit ebf495f013
6 changed files with 86 additions and 923 deletions

View File

@@ -30,13 +30,16 @@
#include <cutlass/gemm/kernel/gemm_universal.hpp>
#include <cutlass/util/packed_stride.hpp>
#include "cutlass_extensions/gemm/collective/collective_builder.hpp"
#include "cutlass_extensions/gemm/dispatch_policy.hpp"
#include "utils.h"
using namespace cute;
template <typename OutType, typename TileShape, typename ClusterShape, int ScaleGranularityM = 1>
template <
typename SchedulerType,
typename OutType,
typename TileShape,
typename ClusterShape,
typename ScaleGranularity>
void launch_sm90_fp8_blockwise_scaled_mm(
torch::Tensor& out,
const torch::Tensor& a,
@@ -63,6 +66,9 @@ void launch_sm90_fp8_blockwise_scaled_mm(
using LayoutD = cutlass::layout::RowMajor;
constexpr int AlignmentD = AlignmentC;
static constexpr int ScaleGranularityM = size<0>(ScaleGranularity{});
static constexpr int ScaleGranularityN = size<1>(ScaleGranularity{});
using ArchTag = cutlass::arch::Sm90;
using OperatorClass = cutlass::arch::OpClassTensorOp;
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
@@ -70,7 +76,7 @@ void launch_sm90_fp8_blockwise_scaled_mm(
using StoreEpilogueCompute = typename cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90AccFetch>;
using KernelSchedule =
cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>;
cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<ScaleGranularityM, ScaleGranularityN>;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OperatorClass,
@@ -108,7 +114,7 @@ void launch_sm90_fp8_blockwise_scaled_mm(
Shape<int, int, int, int>, // Indicates ProblemShape
CollectiveMainloop,
CollectiveEpilogue,
cutlass::gemm::PersistentScheduler>;
SchedulerType>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
Gemm gemm_op;
@@ -299,8 +305,26 @@ void sm90_fp8_blockwise_dispatch_shape(
const torch::Tensor& scales_a,
const torch::Tensor& scales_b) {
using TileShape = Shape<_128, _128, _128>;
using ClusterShape = Shape<_1, _1, _1>;
launch_sm90_fp8_blockwise_scaled_mm<OutType, TileShape, ClusterShape>(out, a, b, scales_a, scales_b);
using ClusterShape = Shape<_1, _2, _1>;
using ScaleGranularity = Shape<_1, _128, _128>;
auto k = a.size(1);
auto n = b.size(1);
if (k > 3 * n) {
launch_sm90_fp8_blockwise_scaled_mm<
cutlass::gemm::StreamKScheduler,
OutType,
TileShape,
ClusterShape,
ScaleGranularity>(out, a, b, scales_a, scales_b);
} else {
launch_sm90_fp8_blockwise_scaled_mm<
cutlass::gemm::PersistentScheduler,
OutType,
TileShape,
ClusterShape,
ScaleGranularity>(out, a, b, scales_a, scales_b);
}
}
template <typename OutType>
@@ -372,10 +396,11 @@ torch::Tensor fp8_blockwise_scaled_mm(
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
if (sm_version == 90) {
torch::Tensor scales_b_contiguous = scales_b.contiguous();
if (out_dtype == torch::kBFloat16) {
sm90_fp8_blockwise_dispatch_shape<cutlass::bfloat16_t>(out, mat_a, mat_b, scales_a, scales_b);
sm90_fp8_blockwise_dispatch_shape<cutlass::bfloat16_t>(out, mat_a, mat_b, scales_a, scales_b_contiguous);
} else {
sm90_fp8_blockwise_dispatch_shape<cutlass::half_t>(out, mat_a, mat_b, scales_a, scales_b);
sm90_fp8_blockwise_dispatch_shape<cutlass::half_t>(out, mat_a, mat_b, scales_a, scales_b_contiguous);
}
return out;
}