[Perf] Tunings for SM100 FP8 CUTLASS kernel (#8818)

This commit is contained in:
henryg
2025-08-13 21:59:22 -07:00
committed by GitHub
parent 733446dd36
commit 841810f227
3 changed files with 177 additions and 27 deletions

View File

@@ -48,6 +48,7 @@ limitations under the License.
#include <cutlass/gemm/kernel/gemm_universal.hpp>
#include <cutlass/util/packed_stride.hpp>
#include "math.hpp"
#include "utils.h"
using namespace cute;
@@ -1019,8 +1020,18 @@ void sm100_fp8_dispatch_bias(
const torch::Tensor& scales_a,
const torch::Tensor& scales_b,
const c10::optional<torch::Tensor>& bias) {
using CTAShape = Shape<_256, _128, _64>;
using ClusterShape = Shape<_2, _2, _1>;
using CTAShapeDefault = Shape<_256, _128, _64>;
using ClusterShapeDefault = Shape<_2, _2, _1>;
using CTAShape256 = Shape<_128, _128, _128>;
using ClusterShape256 = Shape<_2, _1, _1>;
using CTAShape64 = Shape<_64, _64, _128>;
using ClusterShape64 = Shape<_1, _1, _1>;
using CTAShape16 = Shape<_64, _64, _128>;
using ClusterShape16 = Shape<_1, _4, _1>;
using MainloopScheduleType = cutlass::gemm::collective::KernelScheduleAuto;
using EpilogueScheduleType = cutlass::epilogue::collective::EpilogueScheduleAuto;
using TileSchedulerType = void;
@@ -1029,30 +1040,121 @@ void sm100_fp8_dispatch_bias(
using ElementOutput = OutType;
using AccumElementType = float;
// Gemm type with bias
using BiasGemmDefault = DeviceGemmFp8RowwiseSm100<
ElementInput,
ElementOutput,
AccumElementType,
CTAShapeDefault,
ClusterShapeDefault,
MainloopScheduleType,
EpilogueScheduleType,
TileSchedulerType,
true>;
using BiasGemm256 = DeviceGemmFp8RowwiseSm100<
ElementInput,
ElementOutput,
AccumElementType,
CTAShape256,
ClusterShape256,
MainloopScheduleType,
EpilogueScheduleType,
TileSchedulerType,
true>;
using BiasGemm64 = DeviceGemmFp8RowwiseSm100<
ElementInput,
ElementOutput,
AccumElementType,
CTAShape64,
ClusterShape64,
MainloopScheduleType,
EpilogueScheduleType,
TileSchedulerType,
true>;
using BiasGemm16 = DeviceGemmFp8RowwiseSm100<
ElementInput,
ElementOutput,
AccumElementType,
CTAShape16,
ClusterShape16,
MainloopScheduleType,
EpilogueScheduleType,
TileSchedulerType,
true>;
// Gemm type without bias
using GemmDefault = DeviceGemmFp8RowwiseSm100<
ElementInput,
ElementOutput,
AccumElementType,
CTAShapeDefault,
ClusterShapeDefault,
MainloopScheduleType,
EpilogueScheduleType,
TileSchedulerType,
false>;
using Gemm256 = DeviceGemmFp8RowwiseSm100<
ElementInput,
ElementOutput,
AccumElementType,
CTAShape256,
ClusterShape256,
MainloopScheduleType,
EpilogueScheduleType,
TileSchedulerType,
false>;
using Gemm64 = DeviceGemmFp8RowwiseSm100<
ElementInput,
ElementOutput,
AccumElementType,
CTAShape64,
ClusterShape64,
MainloopScheduleType,
EpilogueScheduleType,
TileSchedulerType,
false>;
using Gemm16 = DeviceGemmFp8RowwiseSm100<
ElementInput,
ElementOutput,
AccumElementType,
CTAShape16,
ClusterShape16,
MainloopScheduleType,
EpilogueScheduleType,
TileSchedulerType,
false>;
// next power of 2 (minimum 16)
uint32_t const m = a.size(0);
uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m));
if (bias) {
using Gemm = DeviceGemmFp8RowwiseSm100<
ElementInput,
ElementOutput,
AccumElementType,
CTAShape,
ClusterShape,
MainloopScheduleType,
EpilogueScheduleType,
TileSchedulerType,
true>;
return launch_sm100_fp8_scaled_mm<Gemm, true>(out, a, b, scales_a, scales_b, bias);
if (mp2 <= 16) {
// m in [1, 16]
return launch_sm100_fp8_scaled_mm<BiasGemm16, true>(out, a, b, scales_a, scales_b, bias);
} else if (mp2 <= 64) {
// m in (16, 64]
return launch_sm100_fp8_scaled_mm<BiasGemm64, true>(out, a, b, scales_a, scales_b, bias);
} else if (mp2 <= 256) {
// m in (64, 256]
return launch_sm100_fp8_scaled_mm<BiasGemm256, true>(out, a, b, scales_a, scales_b, bias);
} else {
    // m in (256, inf)
return launch_sm100_fp8_scaled_mm<BiasGemmDefault, true>(out, a, b, scales_a, scales_b, bias);
}
} else {
using Gemm = DeviceGemmFp8RowwiseSm100<
ElementInput,
ElementOutput,
AccumElementType,
CTAShape,
ClusterShape,
MainloopScheduleType,
EpilogueScheduleType,
TileSchedulerType,
false>;
return launch_sm100_fp8_scaled_mm<Gemm, false>(out, a, b, scales_a, scales_b, bias);
if (mp2 <= 16) {
// m in [1, 16]
return launch_sm100_fp8_scaled_mm<Gemm16, false>(out, a, b, scales_a, scales_b, bias);
} else if (mp2 <= 64) {
// m in (16, 64]
return launch_sm100_fp8_scaled_mm<Gemm64, false>(out, a, b, scales_a, scales_b, bias);
} else if (mp2 <= 256) {
// m in (64, 256]
return launch_sm100_fp8_scaled_mm<Gemm256, false>(out, a, b, scales_a, scales_b, bias);
} else {
return launch_sm100_fp8_scaled_mm<GemmDefault, false>(out, a, b, scales_a, scales_b, bias);
}
}
}