Support Blackwell Block Scale FP8 Gemm (#4278)

This commit is contained in:
Elfie Guo
2025-03-12 14:17:11 -07:00
committed by GitHub
parent 10b544ae9b
commit 7c86671131
3 changed files with 207 additions and 2 deletions

View File

@@ -17,6 +17,8 @@
#include <cutlass/matrix_coord.h>
#include <cutlass/numeric_types.h>
#include <cutlass/tensor_ref.h>
#include <cutlass/util/host_tensor.h>
#include <cutlass/util/tensor_view_io.h>
#include <torch/all.h>
#include <cute/tensor.hpp>
@@ -154,6 +156,141 @@ void launch_sm90_fp8_blockwise_scaled_mm(
TORCH_CHECK(status == cutlass::Status::kSuccess, cutlassGetStatusString(status))
}
template <
typename OutType,
typename MmaTileShape,
typename PerSmTileShape,
typename EpilogueTileShape,
typename ScalesPerTile,
int TileSizeM_ = 128,
class ClusterShape = Shape<_1, _1, _1>>
void launch_sm100_fp8_blockwise_scaled_mm(
torch::Tensor& out,
const torch::Tensor& a,
const torch::Tensor& b,
const torch::Tensor& scales_a,
const torch::Tensor& scales_b) {
static constexpr int ScaleMsPerTile = size<0>(ScalesPerTile{});
static constexpr int ScaleGranularityM = size<0>(MmaTileShape{}) / ScaleMsPerTile;
static constexpr int ScaleGranularityN = size<1>(MmaTileShape{}) / size<1>(ScalesPerTile{});
static constexpr int ScaleGranularityK = size<2>(MmaTileShape{}) / size<2>(ScalesPerTile{});
using ElementAB = cutlass::float_e4m3_t;
using ElementA = ElementAB;
using ElementB = ElementAB;
using ElementC = void;
using ElementD = OutType;
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutD = cutlass::layout::RowMajor;
using LayoutC = LayoutD;
// This means both SFA and SFB are column-major.
using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig<
ScaleGranularityM,
ScaleGranularityN,
ScaleGranularityK,
cute::UMMA::Major::MN,
cute::UMMA::Major::K>;
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
static constexpr int AlignmentC = AlignmentD;
using ElementAccumulator = float;
using ElementBlockScale = float;
using ElementCompute = float;
using ArchTag = cutlass::arch::Sm100;
using OperatorClass = cutlass::arch::OpClassTensorOp;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
cutlass::arch::OpClassTensorOp,
PerSmTileShape,
ClusterShape,
EpilogueTileShape,
ElementAccumulator,
ElementCompute,
ElementC,
LayoutC,
AlignmentC,
ElementD,
LayoutD,
AlignmentD,
cutlass::epilogue::TmaWarpSpecialized1Sm>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OperatorClass,
ElementA,
cute::tuple<LayoutA, LayoutSFA>,
AlignmentA,
ElementB,
cute::tuple<LayoutB, LayoutSFB>,
AlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int, int, int, int>,
CollectiveMainloop,
CollectiveEpilogue,
cutlass::gemm::PersistentScheduler>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
Gemm gemm_op;
int m = a.size(0);
int k = a.size(1);
int n = b.size(1);
auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
auto scales_a_ptr = static_cast<float*>(scales_a.data_ptr());
auto scales_b_ptr = static_cast<float*>(scales_b.data_ptr());
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
using StrideA = typename GemmKernel::StrideA;
using StrideB = typename GemmKernel::StrideB;
using StrideD = typename GemmKernel::StrideD;
using StrideC = typename GemmKernel::StrideD;
StrideA a_stride = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));
StrideB b_stride = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1));
StrideC c_stride = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, 1));
LayoutSFA layout_SFA = ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, 1));
LayoutSFB layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1));
typename GemmKernel::MainloopArguments mainloop_args{
a_ptr, a_stride, b_ptr, b_stride, scales_a_ptr, layout_SFA, scales_b_ptr, layout_SFB};
typename GemmKernel::EpilogueArguments epilogue_args{{}, c_ptr, c_stride, c_ptr, c_stride};
epilogue_args.thread.alpha = 1.0f;
typename GemmKernel::Arguments args = {
cutlass::gemm::GemmUniversalMode::kGemm, {m, n, k, 1}, mainloop_args, epilogue_args};
auto can_implement = gemm_op.can_implement(args);
TORCH_CHECK(can_implement == cutlass::Status::kSuccess, cutlassGetStatusString(can_implement))
size_t workspace_size = gemm_op.get_workspace_size(args);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
auto init_status = gemm_op.initialize(args, workspace.get());
TORCH_CHECK(init_status == cutlass::Status::kSuccess, cutlassGetStatusString(init_status));
auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
auto status = gemm_op.run(stream);
TORCH_CHECK(status == cutlass::Status::kSuccess, cutlassGetStatusString(status))
}
template <typename OutType>
void sm90_fp8_blockwise_dispatch_shape(
torch::Tensor& out,
@@ -166,6 +303,30 @@ void sm90_fp8_blockwise_dispatch_shape(
launch_sm90_fp8_blockwise_scaled_mm<OutType, TileShape, ClusterShape>(out, a, b, scales_a, scales_b);
}
template <typename OutType>
void sm100_fp8_blockwise_dispatch_shape(
torch::Tensor& out,
const torch::Tensor& a,
const torch::Tensor& b,
const torch::Tensor& scales_a,
const torch::Tensor& scales_b) {
if (a.size(0) <= 128) {
using MmaTileShape = Shape<_64, _128, _128>;
using PerSmTileShape = Shape<_64, _128, _128>;
using EpilogueTileShape = Shape<_64, _64>;
using ScalesPerTile = Shape<_64, _1, _1>;
launch_sm100_fp8_blockwise_scaled_mm<OutType, MmaTileShape, PerSmTileShape, EpilogueTileShape, ScalesPerTile>(
out, a, b, scales_a, scales_b);
} else {
using MmaTileShape = Shape<_128, _128, _128>;
using PerSmTileShape = Shape<_128, _128, _128>;
using EpilogueTileShape = Shape<_128, _64>;
using ScalesPerTile = Shape<_128, _1, _1>;
launch_sm100_fp8_blockwise_scaled_mm<OutType, MmaTileShape, PerSmTileShape, EpilogueTileShape, ScalesPerTile>(
out, a, b, scales_a, scales_b);
}
}
torch::Tensor fp8_blockwise_scaled_mm(
const torch::Tensor& mat_a,
const torch::Tensor& mat_b,
@@ -210,7 +371,7 @@ torch::Tensor fp8_blockwise_scaled_mm(
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
if (sm_version >= 90) {
if (sm_version == 90) {
if (out_dtype == torch::kBFloat16) {
sm90_fp8_blockwise_dispatch_shape<cutlass::bfloat16_t>(out, mat_a, mat_b, scales_a, scales_b);
} else {
@@ -221,6 +382,25 @@ torch::Tensor fp8_blockwise_scaled_mm(
#endif
#endif
#if defined(CUTLASS_ARCH_MMA_SM100A_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
#if defined CUDA_VERSION && CUDA_VERSION >= 12080
if (sm_version == 100) {
int64_t original_rows = mat_a.size(0);
torch::Tensor mat_a_padded = pad_tensor(mat_a, /*alignment=*/4);
torch::Tensor scales_a_padded = pad_tensor(scales_a, /*alignment=*/4, /*col_major=*/true);
torch::Tensor out_padded = torch::empty({mat_a_padded.size(0), mat_b.size(1)}, out.options());
if (out_dtype == torch::kBFloat16) {
sm100_fp8_blockwise_dispatch_shape<cutlass::bfloat16_t>(
out_padded, mat_a_padded, mat_b, scales_a_padded, scales_b);
} else {
sm100_fp8_blockwise_dispatch_shape<cutlass::half_t>(out_padded, mat_a_padded, mat_b, scales_a_padded, scales_b);
}
return out_padded.slice(0, 0, original_rows);
}
#endif
#endif
TORCH_CHECK_NOT_IMPLEMENTED(
false, "No implemented fp8_blockwise_scaled_mm for current compute capability: ", sm_version);
}