[Fix][Ready]Fix register spilling in cutlass nvfp4 gemm kernel on Blackwell (#8127)
This commit is contained in:
@@ -40,27 +40,21 @@ using namespace cute;
|
|||||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||||
// Kernel Perf config
|
// Kernel Perf config
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct KernelTraits;
|
struct KernelTraits {
|
||||||
|
using MmaTileShape = Shape<_256, _256, _256>;
|
||||||
|
using ClusterShape = Shape<int, int, _1>;
|
||||||
|
using EpilogueTile = Shape<_128, _64>;
|
||||||
|
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized2Sm;
|
||||||
|
using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized2SmNvf4Sm100;
|
||||||
|
};
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
struct KernelTraits<float> {
|
struct KernelTraits<float> {
|
||||||
using MmaTileShape = Shape<_128, _128, _256>;
|
using MmaTileShape = Shape<_128, _128, _256>;
|
||||||
using ClusterShape = Shape<_1, _1, _1>;
|
using ClusterShape = Shape<int, int, _1>;
|
||||||
using PerSmTileShape_MNK = Shape<_128, _128, _256>;
|
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||||
};
|
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized1Sm;
|
||||||
|
using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized1SmNvf4Sm100;
|
||||||
template <>
|
|
||||||
struct KernelTraits<cutlass::half_t> {
|
|
||||||
using MmaTileShape = Shape<_256, _256, _256>;
|
|
||||||
using ClusterShape = Shape<_4, _4, _1>;
|
|
||||||
using PerSmTileShape_MNK = Shape<_128, _256, _256>;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct KernelTraits<cutlass::bfloat16_t> {
|
|
||||||
using MmaTileShape = Shape<_256, _256, _256>;
|
|
||||||
using ClusterShape = Shape<_4, _4, _1>;
|
|
||||||
using PerSmTileShape_MNK = Shape<_128, _256, _256>;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
@@ -90,23 +84,26 @@ struct Fp4GemmSm100 {
|
|||||||
// Kernel Perf config
|
// Kernel Perf config
|
||||||
using MmaTileShape = typename KernelTraits<T>::MmaTileShape;
|
using MmaTileShape = typename KernelTraits<T>::MmaTileShape;
|
||||||
using ClusterShape = typename KernelTraits<T>::ClusterShape;
|
using ClusterShape = typename KernelTraits<T>::ClusterShape;
|
||||||
using PerSmTileShape_MNK = typename KernelTraits<T>::PerSmTileShape_MNK;
|
using EpilogueTile = typename KernelTraits<T>::EpilogueTile;
|
||||||
|
using EpilogueSchedule = typename KernelTraits<T>::EpilogueSchedule;
|
||||||
|
using MainloopSchedule = typename KernelTraits<T>::MainloopSchedule;
|
||||||
|
|
||||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||||
ArchTag,
|
ArchTag,
|
||||||
OperatorClass,
|
cutlass::arch::OpClassTensorOp,
|
||||||
PerSmTileShape_MNK,
|
MmaTileShape,
|
||||||
ClusterShape,
|
ClusterShape,
|
||||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
EpilogueTile,
|
||||||
ElementAccumulator,
|
ElementAccumulator,
|
||||||
ElementAccumulator,
|
ElementAccumulator,
|
||||||
ElementC,
|
void,
|
||||||
LayoutCTag,
|
LayoutCTag,
|
||||||
AlignmentC,
|
AlignmentC,
|
||||||
ElementD,
|
ElementD,
|
||||||
LayoutDTag,
|
LayoutDTag,
|
||||||
AlignmentD,
|
AlignmentD,
|
||||||
cutlass::epilogue::collective::EpilogueScheduleAuto>::CollectiveOp;
|
EpilogueSchedule,
|
||||||
|
cutlass::epilogue::fusion::LinearCombination<ElementD, float, void, float>>::CollectiveOp;
|
||||||
|
|
||||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||||
ArchTag,
|
ArchTag,
|
||||||
@@ -122,7 +119,7 @@ struct Fp4GemmSm100 {
|
|||||||
ClusterShape,
|
ClusterShape,
|
||||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
|
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
|
||||||
sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||||
cutlass::gemm::collective::KernelScheduleAuto>::CollectiveOp;
|
MainloopSchedule>::CollectiveOp;
|
||||||
|
|
||||||
using GemmKernel =
|
using GemmKernel =
|
||||||
cutlass::gemm::kernel::GemmUniversal<Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>;
|
cutlass::gemm::kernel::GemmUniversal<Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>;
|
||||||
@@ -191,6 +188,13 @@ typename T::Gemm::Arguments args_from_options(
|
|||||||
stride_D}};
|
stride_D}};
|
||||||
auto& fusion_args = arguments.epilogue.thread;
|
auto& fusion_args = arguments.epilogue.thread;
|
||||||
fusion_args.alpha_ptr = static_cast<ElementCompute const*>(alpha.data_ptr());
|
fusion_args.alpha_ptr = static_cast<ElementCompute const*>(alpha.data_ptr());
|
||||||
|
if constexpr (std::is_same_v<T, float>) {
|
||||||
|
arguments.hw_info.cluster_shape = dim3(1, 4, 1);
|
||||||
|
arguments.hw_info.cluster_shape_fallback = dim3(1, 1, 1);
|
||||||
|
} else {
|
||||||
|
arguments.hw_info.cluster_shape = dim3(4, 4, 1);
|
||||||
|
arguments.hw_info.cluster_shape_fallback = dim3(2, 1, 1);
|
||||||
|
}
|
||||||
return arguments;
|
return arguments;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user