Update CUTLASS 4.2 & Enable K-Major Scale Factor for SM90 FP8 Blockwise Group GEMM (#9559)
This commit is contained in:
@@ -457,39 +457,40 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape(
|
||||
const torch::Tensor& problem_sizes,
|
||||
const torch::Tensor& expert_offsets,
|
||||
const torch::Tensor& workspace) {
|
||||
struct MmaConfig0 {
|
||||
struct MmaConfigSmallM {
|
||||
// Swap A/B
|
||||
using ElementA = cutlass::float_e4m3_t;
|
||||
using MmaTileShape = Shape<_128, _32, _128>;
|
||||
using ClusterShape = Shape<_2, _1, _1>;
|
||||
// TODO: Check Pingpong or Cooperative
|
||||
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum;
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
|
||||
using ScaleConfig =
|
||||
cutlass::detail::Sm90BlockwiseScaleConfig<128, 1, 128, cute::GMMA::Major::K, cute::GMMA::Major::K>;
|
||||
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
|
||||
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
|
||||
};
|
||||
|
||||
struct MmaConfigH20LargeK {
|
||||
using ElementA = cutlass::float_e4m3_t;
|
||||
using MmaTileShape = Shape<_64, _128, _128>;
|
||||
using ClusterShape = Shape<_2, _1, _1>;
|
||||
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum;
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
|
||||
using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128>;
|
||||
|
||||
using ScaleConfig =
|
||||
cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128, cute::GMMA::Major::K, cute::GMMA::Major::K>;
|
||||
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
|
||||
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
|
||||
};
|
||||
|
||||
struct MmaConfig1 {
|
||||
struct MmaConfigHx00AndH20SmallK {
|
||||
using ElementA = cutlass::float_e4m3_t;
|
||||
using MmaTileShape = Shape<_128, _128, _128>;
|
||||
using ClusterShape = Shape<_1, _2, _1>;
|
||||
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum;
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative;
|
||||
using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128>;
|
||||
|
||||
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
|
||||
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
|
||||
};
|
||||
|
||||
// [NOTE] default for H20
|
||||
struct MmaConfigH20_default {
|
||||
using ElementA = cutlass::float_e4m3_t;
|
||||
using MmaTileShape = Shape<_64, _128, _128>;
|
||||
using ClusterShape = Shape<_1, _2, _1>;
|
||||
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum;
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
|
||||
using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128>;
|
||||
|
||||
using ScaleConfig =
|
||||
cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128, cute::GMMA::Major::K, cute::GMMA::Major::K>;
|
||||
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
|
||||
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
|
||||
};
|
||||
@@ -497,33 +498,34 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape(
|
||||
int num_experts = (int)expert_offsets.size(0);
|
||||
torch::TensorOptions options_int = torch::TensorOptions().dtype(torch::kInt64).device(a.device());
|
||||
torch::Tensor problem_sizes_transpose = torch::empty(num_experts * 3, options_int);
|
||||
torch::Tensor output_t = output.t();
|
||||
torch::Tensor a_t = a.t();
|
||||
torch::Tensor b_t = b.transpose(1, 2);
|
||||
torch::Tensor scales_a_t = scales_a.t();
|
||||
torch::Tensor scales_b_t = scales_b.transpose(1, 2);
|
||||
|
||||
const std::string H20_device_type_str = "NVIDIA H20";
|
||||
bool is_h20_device = isDeviceType(H20_device_type_str);
|
||||
const std::string H20_device_type_str("NVIDIA H20");
|
||||
bool is_h20_device = std::string(at::cuda::getCurrentDeviceProperties()->name) == H20_device_type_str;
|
||||
|
||||
if (is_h20_device) {
|
||||
using execute_gemm_config = MmaConfigH20_default;
|
||||
run_get_group_gemm_starts<
|
||||
execute_gemm_config::LayoutSFA,
|
||||
execute_gemm_config::LayoutSFB,
|
||||
execute_gemm_config::ScaleConfig>(
|
||||
if (a.size(0) <= 2048) {
|
||||
run_get_group_gemm_starts<MmaConfigSmallM::LayoutSFA, MmaConfigSmallM::LayoutSFB, MmaConfigSmallM::ScaleConfig>(
|
||||
expert_offsets,
|
||||
a_ptrs,
|
||||
b_ptrs,
|
||||
out_ptrs,
|
||||
a_scales_ptrs,
|
||||
b_scales_ptrs,
|
||||
a,
|
||||
b,
|
||||
output,
|
||||
scales_a,
|
||||
scales_b,
|
||||
b_t,
|
||||
a_t,
|
||||
output_t,
|
||||
scales_b_t,
|
||||
scales_a_t,
|
||||
layout_sfa,
|
||||
layout_sfb,
|
||||
problem_sizes,
|
||||
problem_sizes_transpose);
|
||||
|
||||
launch_sm90_fp8_blockwise_scaled_group_mm<OutType, execute_gemm_config, cutlass::layout::RowMajor>(
|
||||
problem_sizes_transpose,
|
||||
true);
|
||||
launch_sm90_fp8_blockwise_scaled_group_mm<OutType, MmaConfigSmallM, cutlass::layout::ColumnMajor>(
|
||||
out_ptrs,
|
||||
a_ptrs,
|
||||
b_ptrs,
|
||||
@@ -534,13 +536,17 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape(
|
||||
stride_c,
|
||||
layout_sfa,
|
||||
layout_sfb,
|
||||
problem_sizes,
|
||||
problem_sizes_transpose,
|
||||
expert_offsets,
|
||||
workspace);
|
||||
output = output_t.t();
|
||||
} else {
|
||||
if (at::cuda::getCurrentDeviceProperties()->multiProcessorCount == 78 && a.size(1) > 128) {
|
||||
if (is_h20_device && a.size(1) > 128) {
|
||||
// For H20 with K > 128, use Pingpong Schedule
|
||||
run_get_group_gemm_starts<MmaConfig0::LayoutSFA, MmaConfig0::LayoutSFB, MmaConfig0::ScaleConfig>(
|
||||
run_get_group_gemm_starts<
|
||||
MmaConfigH20LargeK::LayoutSFA,
|
||||
MmaConfigH20LargeK::LayoutSFB,
|
||||
MmaConfigH20LargeK::ScaleConfig>(
|
||||
expert_offsets,
|
||||
a_ptrs,
|
||||
b_ptrs,
|
||||
@@ -556,7 +562,7 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape(
|
||||
layout_sfb,
|
||||
problem_sizes,
|
||||
problem_sizes_transpose);
|
||||
launch_sm90_fp8_blockwise_scaled_group_mm<OutType, MmaConfig0, cutlass::layout::RowMajor>(
|
||||
launch_sm90_fp8_blockwise_scaled_group_mm<OutType, MmaConfigH20LargeK, cutlass::layout::RowMajor>(
|
||||
out_ptrs,
|
||||
a_ptrs,
|
||||
b_ptrs,
|
||||
@@ -572,7 +578,10 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape(
|
||||
workspace);
|
||||
} else {
|
||||
// For H20 with K <= 128, and H100 & H200 & H800, use Cooperative Schedule
|
||||
run_get_group_gemm_starts<MmaConfig1::LayoutSFA, MmaConfig1::LayoutSFB, MmaConfig1::ScaleConfig>(
|
||||
run_get_group_gemm_starts<
|
||||
MmaConfigHx00AndH20SmallK::LayoutSFA,
|
||||
MmaConfigHx00AndH20SmallK::LayoutSFB,
|
||||
MmaConfigHx00AndH20SmallK::ScaleConfig>(
|
||||
expert_offsets,
|
||||
a_ptrs,
|
||||
b_ptrs,
|
||||
@@ -588,7 +597,7 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape(
|
||||
layout_sfb,
|
||||
problem_sizes,
|
||||
problem_sizes_transpose);
|
||||
launch_sm90_fp8_blockwise_scaled_group_mm<OutType, MmaConfig1, cutlass::layout::RowMajor>(
|
||||
launch_sm90_fp8_blockwise_scaled_group_mm<OutType, MmaConfigHx00AndH20SmallK, cutlass::layout::RowMajor>(
|
||||
out_ptrs,
|
||||
a_ptrs,
|
||||
b_ptrs,
|
||||
|
||||
Reference in New Issue
Block a user