Update CUTLASS. Refine KernelSchedule for fp8 (grouped) gemm. (#10491)
This commit is contained in:
@@ -46,7 +46,7 @@ include(FetchContent)
|
|||||||
FetchContent_Declare(
|
FetchContent_Declare(
|
||||||
repo-cutlass
|
repo-cutlass
|
||||||
GIT_REPOSITORY https://github.com/NVIDIA/cutlass
|
GIT_REPOSITORY https://github.com/NVIDIA/cutlass
|
||||||
GIT_TAG a49a78ffefc86a87160dfe0ccc3a3a2d1622c918
|
GIT_TAG 57e3cfb47a2d9e0d46eb6335c3dc411498efa198
|
||||||
GIT_SHALLOW OFF
|
GIT_SHALLOW OFF
|
||||||
)
|
)
|
||||||
FetchContent_Populate(repo-cutlass)
|
FetchContent_Populate(repo-cutlass)
|
||||||
|
|||||||
@@ -72,7 +72,7 @@ struct cutlass_3x_gemm_fp8_blockwise {
|
|||||||
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
|
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||||
using StoreEpilogueCompute = typename cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90AccFetch>;
|
using StoreEpilogueCompute = typename cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90AccFetch>;
|
||||||
|
|
||||||
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum;
|
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8Blockwise;
|
||||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||||
ArchTag,
|
ArchTag,
|
||||||
OperatorClass,
|
OperatorClass,
|
||||||
|
|||||||
@@ -463,7 +463,7 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape(
|
|||||||
using MmaTileShape = Shape<_128, _32, _128>;
|
using MmaTileShape = Shape<_128, _32, _128>;
|
||||||
using ClusterShape = Shape<_2, _1, _1>;
|
using ClusterShape = Shape<_2, _1, _1>;
|
||||||
// TODO: Check Pingpong or Cooperative
|
// TODO: Check Pingpong or Cooperative
|
||||||
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum;
|
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8Blockwise;
|
||||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
|
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
|
||||||
using ScaleConfig =
|
using ScaleConfig =
|
||||||
cutlass::detail::Sm90BlockwiseScaleConfig<128, 1, 128, cute::GMMA::Major::K, cute::GMMA::Major::K>;
|
cutlass::detail::Sm90BlockwiseScaleConfig<128, 1, 128, cute::GMMA::Major::K, cute::GMMA::Major::K>;
|
||||||
@@ -475,7 +475,7 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape(
|
|||||||
using ElementA = cutlass::float_e4m3_t;
|
using ElementA = cutlass::float_e4m3_t;
|
||||||
using MmaTileShape = Shape<_64, _128, _128>;
|
using MmaTileShape = Shape<_64, _128, _128>;
|
||||||
using ClusterShape = Shape<_2, _1, _1>;
|
using ClusterShape = Shape<_2, _1, _1>;
|
||||||
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum;
|
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8Blockwise;
|
||||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
|
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
|
||||||
using ScaleConfig =
|
using ScaleConfig =
|
||||||
cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128, cute::GMMA::Major::K, cute::GMMA::Major::K>;
|
cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128, cute::GMMA::Major::K, cute::GMMA::Major::K>;
|
||||||
@@ -487,7 +487,7 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape(
|
|||||||
using ElementA = cutlass::float_e4m3_t;
|
using ElementA = cutlass::float_e4m3_t;
|
||||||
using MmaTileShape = Shape<_128, _128, _128>;
|
using MmaTileShape = Shape<_128, _128, _128>;
|
||||||
using ClusterShape = Shape<_1, _2, _1>;
|
using ClusterShape = Shape<_1, _2, _1>;
|
||||||
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum;
|
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8Blockwise;
|
||||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative;
|
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative;
|
||||||
using ScaleConfig =
|
using ScaleConfig =
|
||||||
cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128, cute::GMMA::Major::K, cute::GMMA::Major::K>;
|
cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128, cute::GMMA::Major::K, cute::GMMA::Major::K>;
|
||||||
|
|||||||
Reference in New Issue
Block a user