Update CUTLASS 4.2 & Enable K-Major Scale Factor for SM90 FP8 Blockwise Group GEMM (#9559)

This commit is contained in:
Qi Yuhang
2025-08-25 14:24:43 +08:00
committed by GitHub
parent a0b22f2f17
commit fda4792620
5 changed files with 104 additions and 134 deletions

View File

@@ -457,39 +457,40 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape(
const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets,
const torch::Tensor& workspace) {
struct MmaConfig0 {
struct MmaConfigSmallM {
// Swap A/B
using ElementA = cutlass::float_e4m3_t;
using MmaTileShape = Shape<_128, _32, _128>;
using ClusterShape = Shape<_2, _1, _1>;
// TODO: Check Pingpong or Cooperative
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum;
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
using ScaleConfig =
cutlass::detail::Sm90BlockwiseScaleConfig<128, 1, 128, cute::GMMA::Major::K, cute::GMMA::Major::K>;
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
};
struct MmaConfigH20LargeK {
using ElementA = cutlass::float_e4m3_t;
using MmaTileShape = Shape<_64, _128, _128>;
using ClusterShape = Shape<_2, _1, _1>;
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum;
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128>;
using ScaleConfig =
cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128, cute::GMMA::Major::K, cute::GMMA::Major::K>;
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
};
struct MmaConfig1 {
struct MmaConfigHx00AndH20SmallK {
using ElementA = cutlass::float_e4m3_t;
using MmaTileShape = Shape<_128, _128, _128>;
using ClusterShape = Shape<_1, _2, _1>;
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum;
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative;
using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128>;
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
};
// [NOTE] default for H20
struct MmaConfigH20_default {
using ElementA = cutlass::float_e4m3_t;
using MmaTileShape = Shape<_64, _128, _128>;
using ClusterShape = Shape<_1, _2, _1>;
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum;
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128>;
using ScaleConfig =
cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128, cute::GMMA::Major::K, cute::GMMA::Major::K>;
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
};
@@ -497,33 +498,34 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape(
int num_experts = (int)expert_offsets.size(0);
torch::TensorOptions options_int = torch::TensorOptions().dtype(torch::kInt64).device(a.device());
torch::Tensor problem_sizes_transpose = torch::empty(num_experts * 3, options_int);
torch::Tensor output_t = output.t();
torch::Tensor a_t = a.t();
torch::Tensor b_t = b.transpose(1, 2);
torch::Tensor scales_a_t = scales_a.t();
torch::Tensor scales_b_t = scales_b.transpose(1, 2);
const std::string H20_device_type_str = "NVIDIA H20";
bool is_h20_device = isDeviceType(H20_device_type_str);
const std::string H20_device_type_str("NVIDIA H20");
bool is_h20_device = std::string(at::cuda::getCurrentDeviceProperties()->name) == H20_device_type_str;
if (is_h20_device) {
using execute_gemm_config = MmaConfigH20_default;
run_get_group_gemm_starts<
execute_gemm_config::LayoutSFA,
execute_gemm_config::LayoutSFB,
execute_gemm_config::ScaleConfig>(
if (a.size(0) <= 2048) {
run_get_group_gemm_starts<MmaConfigSmallM::LayoutSFA, MmaConfigSmallM::LayoutSFB, MmaConfigSmallM::ScaleConfig>(
expert_offsets,
a_ptrs,
b_ptrs,
out_ptrs,
a_scales_ptrs,
b_scales_ptrs,
a,
b,
output,
scales_a,
scales_b,
b_t,
a_t,
output_t,
scales_b_t,
scales_a_t,
layout_sfa,
layout_sfb,
problem_sizes,
problem_sizes_transpose);
launch_sm90_fp8_blockwise_scaled_group_mm<OutType, execute_gemm_config, cutlass::layout::RowMajor>(
problem_sizes_transpose,
true);
launch_sm90_fp8_blockwise_scaled_group_mm<OutType, MmaConfigSmallM, cutlass::layout::ColumnMajor>(
out_ptrs,
a_ptrs,
b_ptrs,
@@ -534,13 +536,17 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape(
stride_c,
layout_sfa,
layout_sfb,
problem_sizes,
problem_sizes_transpose,
expert_offsets,
workspace);
output = output_t.t();
} else {
if (at::cuda::getCurrentDeviceProperties()->multiProcessorCount == 78 && a.size(1) > 128) {
if (is_h20_device && a.size(1) > 128) {
// For H20 with K > 128, use Pingpong Schedule
run_get_group_gemm_starts<MmaConfig0::LayoutSFA, MmaConfig0::LayoutSFB, MmaConfig0::ScaleConfig>(
run_get_group_gemm_starts<
MmaConfigH20LargeK::LayoutSFA,
MmaConfigH20LargeK::LayoutSFB,
MmaConfigH20LargeK::ScaleConfig>(
expert_offsets,
a_ptrs,
b_ptrs,
@@ -556,7 +562,7 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape(
layout_sfb,
problem_sizes,
problem_sizes_transpose);
launch_sm90_fp8_blockwise_scaled_group_mm<OutType, MmaConfig0, cutlass::layout::RowMajor>(
launch_sm90_fp8_blockwise_scaled_group_mm<OutType, MmaConfigH20LargeK, cutlass::layout::RowMajor>(
out_ptrs,
a_ptrs,
b_ptrs,
@@ -572,7 +578,10 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape(
workspace);
} else {
// For H20 with K <= 128, and H100 & H200 & H800, use Cooperative Schedule
run_get_group_gemm_starts<MmaConfig1::LayoutSFA, MmaConfig1::LayoutSFB, MmaConfig1::ScaleConfig>(
run_get_group_gemm_starts<
MmaConfigHx00AndH20SmallK::LayoutSFA,
MmaConfigHx00AndH20SmallK::LayoutSFB,
MmaConfigHx00AndH20SmallK::ScaleConfig>(
expert_offsets,
a_ptrs,
b_ptrs,
@@ -588,7 +597,7 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape(
layout_sfb,
problem_sizes,
problem_sizes_transpose);
launch_sm90_fp8_blockwise_scaled_group_mm<OutType, MmaConfig1, cutlass::layout::RowMajor>(
launch_sm90_fp8_blockwise_scaled_group_mm<OutType, MmaConfigHx00AndH20SmallK, cutlass::layout::RowMajor>(
out_ptrs,
a_ptrs,
b_ptrs,