Support CUTLASS 3.9 to improve fp8_blockwise_gemm (#5820)
This commit is contained in:
@@ -43,7 +43,7 @@ include(FetchContent)
|
|||||||
FetchContent_Declare(
|
FetchContent_Declare(
|
||||||
repo-cutlass
|
repo-cutlass
|
||||||
GIT_REPOSITORY https://github.com/NVIDIA/cutlass
|
GIT_REPOSITORY https://github.com/NVIDIA/cutlass
|
||||||
GIT_TAG 5e497243f7ad13a2aa842143f9b10bbb23d98292
|
GIT_TAG e94e888df3551224738bfa505787b515eae8352f
|
||||||
GIT_SHALLOW OFF
|
GIT_SHALLOW OFF
|
||||||
)
|
)
|
||||||
FetchContent_Populate(repo-cutlass)
|
FetchContent_Populate(repo-cutlass)
|
||||||
|
|||||||
@@ -34,12 +34,7 @@
|
|||||||
|
|
||||||
using namespace cute;
|
using namespace cute;
|
||||||
|
|
||||||
template <
|
template <typename SchedulerType, typename OutType, typename TileShape, typename ClusterShape>
|
||||||
typename SchedulerType,
|
|
||||||
typename OutType,
|
|
||||||
typename TileShape,
|
|
||||||
typename ClusterShape,
|
|
||||||
typename ScaleGranularity>
|
|
||||||
void launch_sm90_fp8_blockwise_scaled_mm(
|
void launch_sm90_fp8_blockwise_scaled_mm(
|
||||||
torch::Tensor& out,
|
torch::Tensor& out,
|
||||||
const torch::Tensor& a,
|
const torch::Tensor& a,
|
||||||
@@ -66,8 +61,10 @@ void launch_sm90_fp8_blockwise_scaled_mm(
|
|||||||
using LayoutD = cutlass::layout::RowMajor;
|
using LayoutD = cutlass::layout::RowMajor;
|
||||||
constexpr int AlignmentD = AlignmentC;
|
constexpr int AlignmentD = AlignmentC;
|
||||||
|
|
||||||
static constexpr int ScaleGranularityM = size<0>(ScaleGranularity{});
|
using ScaleTileShape = Shape<_1, _128, _128>;
|
||||||
static constexpr int ScaleGranularityN = size<1>(ScaleGranularity{});
|
using ScaleConfig = decltype(cutlass::detail::sm90_trivial_blockwise_scale_config(ScaleTileShape{}));
|
||||||
|
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
|
||||||
|
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
|
||||||
|
|
||||||
using ArchTag = cutlass::arch::Sm90;
|
using ArchTag = cutlass::arch::Sm90;
|
||||||
using OperatorClass = cutlass::arch::OpClassTensorOp;
|
using OperatorClass = cutlass::arch::OpClassTensorOp;
|
||||||
@@ -75,8 +72,7 @@ void launch_sm90_fp8_blockwise_scaled_mm(
|
|||||||
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
|
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||||
using StoreEpilogueCompute = typename cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90AccFetch>;
|
using StoreEpilogueCompute = typename cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90AccFetch>;
|
||||||
|
|
||||||
using KernelSchedule =
|
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum;
|
||||||
cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<ScaleGranularityM, ScaleGranularityN>;
|
|
||||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||||
ArchTag,
|
ArchTag,
|
||||||
OperatorClass,
|
OperatorClass,
|
||||||
@@ -98,10 +94,10 @@ void launch_sm90_fp8_blockwise_scaled_mm(
|
|||||||
ArchTag,
|
ArchTag,
|
||||||
OperatorClass,
|
OperatorClass,
|
||||||
ElementA,
|
ElementA,
|
||||||
LayoutA,
|
cute::tuple<LayoutA, LayoutSFA>,
|
||||||
AlignmentA,
|
AlignmentA,
|
||||||
ElementB,
|
ElementB,
|
||||||
LayoutB,
|
cute::tuple<LayoutB, LayoutSFB>,
|
||||||
AlignmentB,
|
AlignmentB,
|
||||||
ElementAccumulator,
|
ElementAccumulator,
|
||||||
TileShape,
|
TileShape,
|
||||||
@@ -140,7 +136,11 @@ void launch_sm90_fp8_blockwise_scaled_mm(
|
|||||||
StrideC stride_c;
|
StrideC stride_c;
|
||||||
StrideD stride_d = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(m, n, 1));
|
StrideD stride_d = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(m, n, 1));
|
||||||
|
|
||||||
typename GemmKernel::MainloopArguments mainloop_args{a_ptr, stride_a, b_ptr, stride_b, 4, a_s_ptr, b_s_ptr};
|
LayoutSFA layout_sfa = ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, 1));
|
||||||
|
LayoutSFB layout_sfb = ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1));
|
||||||
|
|
||||||
|
typename GemmKernel::MainloopArguments mainloop_args{
|
||||||
|
a_ptr, stride_a, b_ptr, stride_b, 4, a_s_ptr, layout_sfa, b_s_ptr, layout_sfb};
|
||||||
typename GemmKernel::EpilogueArguments epilogue_args{{}, nullptr, stride_d, o_ptr, stride_d};
|
typename GemmKernel::EpilogueArguments epilogue_args{{}, nullptr, stride_d, o_ptr, stride_d};
|
||||||
|
|
||||||
typename Gemm::Arguments args = {
|
typename Gemm::Arguments args = {
|
||||||
@@ -306,24 +306,15 @@ void sm90_fp8_blockwise_dispatch_shape(
|
|||||||
const torch::Tensor& scales_b) {
|
const torch::Tensor& scales_b) {
|
||||||
using TileShape = Shape<_128, _128, _128>;
|
using TileShape = Shape<_128, _128, _128>;
|
||||||
using ClusterShape = Shape<_1, _2, _1>;
|
using ClusterShape = Shape<_1, _2, _1>;
|
||||||
using ScaleGranularity = Shape<_1, _128, _128>;
|
|
||||||
|
|
||||||
auto k = a.size(1);
|
auto k = a.size(1);
|
||||||
auto n = b.size(1);
|
auto n = b.size(1);
|
||||||
if (k > 3 * n) {
|
if (k > 3 * n) {
|
||||||
launch_sm90_fp8_blockwise_scaled_mm<
|
launch_sm90_fp8_blockwise_scaled_mm<cutlass::gemm::StreamKScheduler, OutType, TileShape, ClusterShape>(
|
||||||
cutlass::gemm::StreamKScheduler,
|
out, a, b, scales_a, scales_b);
|
||||||
OutType,
|
|
||||||
TileShape,
|
|
||||||
ClusterShape,
|
|
||||||
ScaleGranularity>(out, a, b, scales_a, scales_b);
|
|
||||||
} else {
|
} else {
|
||||||
launch_sm90_fp8_blockwise_scaled_mm<
|
launch_sm90_fp8_blockwise_scaled_mm<cutlass::gemm::PersistentScheduler, OutType, TileShape, ClusterShape>(
|
||||||
cutlass::gemm::PersistentScheduler,
|
out, a, b, scales_a, scales_b);
|
||||||
OutType,
|
|
||||||
TileShape,
|
|
||||||
ClusterShape,
|
|
||||||
ScaleGranularity>(out, a, b, scales_a, scales_b);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user