sgl-kernel: use the latest CUTLASS version for fp8 blockwise gemm (#5207)
This commit is contained in:
@@ -30,13 +30,16 @@
|
||||
#include <cutlass/gemm/kernel/gemm_universal.hpp>
|
||||
#include <cutlass/util/packed_stride.hpp>
|
||||
|
||||
#include "cutlass_extensions/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass_extensions/gemm/dispatch_policy.hpp"
|
||||
#include "utils.h"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template <typename OutType, typename TileShape, typename ClusterShape, int ScaleGranularityM = 1>
|
||||
template <
|
||||
typename SchedulerType,
|
||||
typename OutType,
|
||||
typename TileShape,
|
||||
typename ClusterShape,
|
||||
typename ScaleGranularity>
|
||||
void launch_sm90_fp8_blockwise_scaled_mm(
|
||||
torch::Tensor& out,
|
||||
const torch::Tensor& a,
|
||||
@@ -63,6 +66,9 @@ void launch_sm90_fp8_blockwise_scaled_mm(
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
constexpr int AlignmentD = AlignmentC;
|
||||
|
||||
static constexpr int ScaleGranularityM = size<0>(ScaleGranularity{});
|
||||
static constexpr int ScaleGranularityN = size<1>(ScaleGranularity{});
|
||||
|
||||
using ArchTag = cutlass::arch::Sm90;
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp;
|
||||
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
|
||||
@@ -70,7 +76,7 @@ void launch_sm90_fp8_blockwise_scaled_mm(
|
||||
using StoreEpilogueCompute = typename cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90AccFetch>;
|
||||
|
||||
using KernelSchedule =
|
||||
cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>;
|
||||
cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<ScaleGranularityM, ScaleGranularityN>;
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OperatorClass,
|
||||
@@ -108,7 +114,7 @@ void launch_sm90_fp8_blockwise_scaled_mm(
|
||||
Shape<int, int, int, int>, // Indicates ProblemShape
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
cutlass::gemm::PersistentScheduler>;
|
||||
SchedulerType>;
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
Gemm gemm_op;
|
||||
@@ -299,8 +305,26 @@ void sm90_fp8_blockwise_dispatch_shape(
|
||||
const torch::Tensor& scales_a,
|
||||
const torch::Tensor& scales_b) {
|
||||
using TileShape = Shape<_128, _128, _128>;
|
||||
using ClusterShape = Shape<_1, _1, _1>;
|
||||
launch_sm90_fp8_blockwise_scaled_mm<OutType, TileShape, ClusterShape>(out, a, b, scales_a, scales_b);
|
||||
using ClusterShape = Shape<_1, _2, _1>;
|
||||
using ScaleGranularity = Shape<_1, _128, _128>;
|
||||
|
||||
auto k = a.size(1);
|
||||
auto n = b.size(1);
|
||||
if (k > 3 * n) {
|
||||
launch_sm90_fp8_blockwise_scaled_mm<
|
||||
cutlass::gemm::StreamKScheduler,
|
||||
OutType,
|
||||
TileShape,
|
||||
ClusterShape,
|
||||
ScaleGranularity>(out, a, b, scales_a, scales_b);
|
||||
} else {
|
||||
launch_sm90_fp8_blockwise_scaled_mm<
|
||||
cutlass::gemm::PersistentScheduler,
|
||||
OutType,
|
||||
TileShape,
|
||||
ClusterShape,
|
||||
ScaleGranularity>(out, a, b, scales_a, scales_b);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename OutType>
|
||||
@@ -372,10 +396,11 @@ torch::Tensor fp8_blockwise_scaled_mm(
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
|
||||
if (sm_version == 90) {
|
||||
torch::Tensor scales_b_contiguous = scales_b.contiguous();
|
||||
if (out_dtype == torch::kBFloat16) {
|
||||
sm90_fp8_blockwise_dispatch_shape<cutlass::bfloat16_t>(out, mat_a, mat_b, scales_a, scales_b);
|
||||
sm90_fp8_blockwise_dispatch_shape<cutlass::bfloat16_t>(out, mat_a, mat_b, scales_a, scales_b_contiguous);
|
||||
} else {
|
||||
sm90_fp8_blockwise_dispatch_shape<cutlass::half_t>(out, mat_a, mat_b, scales_a, scales_b);
|
||||
sm90_fp8_blockwise_dispatch_shape<cutlass::half_t>(out, mat_a, mat_b, scales_a, scales_b_contiguous);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user