Support CUTLASS 3.9 to improve fp8_blockwise_gemm (#5820)
This commit is contained in:
@@ -43,7 +43,7 @@ include(FetchContent)
|
||||
FetchContent_Declare(
|
||||
repo-cutlass
|
||||
GIT_REPOSITORY https://github.com/NVIDIA/cutlass
|
||||
GIT_TAG 5e497243f7ad13a2aa842143f9b10bbb23d98292
|
||||
GIT_TAG e94e888df3551224738bfa505787b515eae8352f
|
||||
GIT_SHALLOW OFF
|
||||
)
|
||||
FetchContent_Populate(repo-cutlass)
|
||||
|
||||
@@ -34,12 +34,7 @@
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template <
|
||||
typename SchedulerType,
|
||||
typename OutType,
|
||||
typename TileShape,
|
||||
typename ClusterShape,
|
||||
typename ScaleGranularity>
|
||||
template <typename SchedulerType, typename OutType, typename TileShape, typename ClusterShape>
|
||||
void launch_sm90_fp8_blockwise_scaled_mm(
|
||||
torch::Tensor& out,
|
||||
const torch::Tensor& a,
|
||||
@@ -66,8 +61,10 @@ void launch_sm90_fp8_blockwise_scaled_mm(
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
constexpr int AlignmentD = AlignmentC;
|
||||
|
||||
static constexpr int ScaleGranularityM = size<0>(ScaleGranularity{});
|
||||
static constexpr int ScaleGranularityN = size<1>(ScaleGranularity{});
|
||||
using ScaleTileShape = Shape<_1, _128, _128>;
|
||||
using ScaleConfig = decltype(cutlass::detail::sm90_trivial_blockwise_scale_config(ScaleTileShape{}));
|
||||
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
|
||||
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
|
||||
|
||||
using ArchTag = cutlass::arch::Sm90;
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp;
|
||||
@@ -75,8 +72,7 @@ void launch_sm90_fp8_blockwise_scaled_mm(
|
||||
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using StoreEpilogueCompute = typename cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90AccFetch>;
|
||||
|
||||
using KernelSchedule =
|
||||
cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<ScaleGranularityM, ScaleGranularityN>;
|
||||
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum;
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OperatorClass,
|
||||
@@ -98,10 +94,10 @@ void launch_sm90_fp8_blockwise_scaled_mm(
|
||||
ArchTag,
|
||||
OperatorClass,
|
||||
ElementA,
|
||||
LayoutA,
|
||||
cute::tuple<LayoutA, LayoutSFA>,
|
||||
AlignmentA,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
cute::tuple<LayoutB, LayoutSFB>,
|
||||
AlignmentB,
|
||||
ElementAccumulator,
|
||||
TileShape,
|
||||
@@ -140,7 +136,11 @@ void launch_sm90_fp8_blockwise_scaled_mm(
|
||||
StrideC stride_c;
|
||||
StrideD stride_d = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(m, n, 1));
|
||||
|
||||
typename GemmKernel::MainloopArguments mainloop_args{a_ptr, stride_a, b_ptr, stride_b, 4, a_s_ptr, b_s_ptr};
|
||||
LayoutSFA layout_sfa = ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, 1));
|
||||
LayoutSFB layout_sfb = ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1));
|
||||
|
||||
typename GemmKernel::MainloopArguments mainloop_args{
|
||||
a_ptr, stride_a, b_ptr, stride_b, 4, a_s_ptr, layout_sfa, b_s_ptr, layout_sfb};
|
||||
typename GemmKernel::EpilogueArguments epilogue_args{{}, nullptr, stride_d, o_ptr, stride_d};
|
||||
|
||||
typename Gemm::Arguments args = {
|
||||
@@ -306,24 +306,15 @@ void sm90_fp8_blockwise_dispatch_shape(
|
||||
const torch::Tensor& scales_b) {
|
||||
using TileShape = Shape<_128, _128, _128>;
|
||||
using ClusterShape = Shape<_1, _2, _1>;
|
||||
using ScaleGranularity = Shape<_1, _128, _128>;
|
||||
|
||||
auto k = a.size(1);
|
||||
auto n = b.size(1);
|
||||
if (k > 3 * n) {
|
||||
launch_sm90_fp8_blockwise_scaled_mm<
|
||||
cutlass::gemm::StreamKScheduler,
|
||||
OutType,
|
||||
TileShape,
|
||||
ClusterShape,
|
||||
ScaleGranularity>(out, a, b, scales_a, scales_b);
|
||||
launch_sm90_fp8_blockwise_scaled_mm<cutlass::gemm::StreamKScheduler, OutType, TileShape, ClusterShape>(
|
||||
out, a, b, scales_a, scales_b);
|
||||
} else {
|
||||
launch_sm90_fp8_blockwise_scaled_mm<
|
||||
cutlass::gemm::PersistentScheduler,
|
||||
OutType,
|
||||
TileShape,
|
||||
ClusterShape,
|
||||
ScaleGranularity>(out, a, b, scales_a, scales_b);
|
||||
launch_sm90_fp8_blockwise_scaled_mm<cutlass::gemm::PersistentScheduler, OutType, TileShape, ClusterShape>(
|
||||
out, a, b, scales_a, scales_b);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user