[Perf] Tunings for SM100 FP8 CUTLASS kernel (#8818)
This commit is contained in:
@@ -48,6 +48,7 @@ limitations under the License.
|
||||
#include <cutlass/gemm/kernel/gemm_universal.hpp>
|
||||
#include <cutlass/util/packed_stride.hpp>
|
||||
|
||||
#include "math.hpp"
|
||||
#include "utils.h"
|
||||
|
||||
using namespace cute;
|
||||
@@ -1019,8 +1020,18 @@ void sm100_fp8_dispatch_bias(
|
||||
const torch::Tensor& scales_a,
|
||||
const torch::Tensor& scales_b,
|
||||
const c10::optional<torch::Tensor>& bias) {
|
||||
using CTAShape = Shape<_256, _128, _64>;
|
||||
using ClusterShape = Shape<_2, _2, _1>;
|
||||
using CTAShapeDefault = Shape<_256, _128, _64>;
|
||||
using ClusterShapeDefault = Shape<_2, _2, _1>;
|
||||
|
||||
using CTAShape256 = Shape<_128, _128, _128>;
|
||||
using ClusterShape256 = Shape<_2, _1, _1>;
|
||||
|
||||
using CTAShape64 = Shape<_64, _64, _128>;
|
||||
using ClusterShape64 = Shape<_1, _1, _1>;
|
||||
|
||||
using CTAShape16 = Shape<_64, _64, _128>;
|
||||
using ClusterShape16 = Shape<_1, _4, _1>;
|
||||
|
||||
using MainloopScheduleType = cutlass::gemm::collective::KernelScheduleAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::collective::EpilogueScheduleAuto;
|
||||
using TileSchedulerType = void;
|
||||
@@ -1029,30 +1040,121 @@ void sm100_fp8_dispatch_bias(
|
||||
using ElementOutput = OutType;
|
||||
using AccumElementType = float;
|
||||
|
||||
// Gemm type with bias
|
||||
using BiasGemmDefault = DeviceGemmFp8RowwiseSm100<
|
||||
ElementInput,
|
||||
ElementOutput,
|
||||
AccumElementType,
|
||||
CTAShapeDefault,
|
||||
ClusterShapeDefault,
|
||||
MainloopScheduleType,
|
||||
EpilogueScheduleType,
|
||||
TileSchedulerType,
|
||||
true>;
|
||||
using BiasGemm256 = DeviceGemmFp8RowwiseSm100<
|
||||
ElementInput,
|
||||
ElementOutput,
|
||||
AccumElementType,
|
||||
CTAShape256,
|
||||
ClusterShape256,
|
||||
MainloopScheduleType,
|
||||
EpilogueScheduleType,
|
||||
TileSchedulerType,
|
||||
true>;
|
||||
using BiasGemm64 = DeviceGemmFp8RowwiseSm100<
|
||||
ElementInput,
|
||||
ElementOutput,
|
||||
AccumElementType,
|
||||
CTAShape64,
|
||||
ClusterShape64,
|
||||
MainloopScheduleType,
|
||||
EpilogueScheduleType,
|
||||
TileSchedulerType,
|
||||
true>;
|
||||
using BiasGemm16 = DeviceGemmFp8RowwiseSm100<
|
||||
ElementInput,
|
||||
ElementOutput,
|
||||
AccumElementType,
|
||||
CTAShape16,
|
||||
ClusterShape16,
|
||||
MainloopScheduleType,
|
||||
EpilogueScheduleType,
|
||||
TileSchedulerType,
|
||||
true>;
|
||||
|
||||
// Gemm type without bias
|
||||
using GemmDefault = DeviceGemmFp8RowwiseSm100<
|
||||
ElementInput,
|
||||
ElementOutput,
|
||||
AccumElementType,
|
||||
CTAShapeDefault,
|
||||
ClusterShapeDefault,
|
||||
MainloopScheduleType,
|
||||
EpilogueScheduleType,
|
||||
TileSchedulerType,
|
||||
false>;
|
||||
using Gemm256 = DeviceGemmFp8RowwiseSm100<
|
||||
ElementInput,
|
||||
ElementOutput,
|
||||
AccumElementType,
|
||||
CTAShape256,
|
||||
ClusterShape256,
|
||||
MainloopScheduleType,
|
||||
EpilogueScheduleType,
|
||||
TileSchedulerType,
|
||||
false>;
|
||||
using Gemm64 = DeviceGemmFp8RowwiseSm100<
|
||||
ElementInput,
|
||||
ElementOutput,
|
||||
AccumElementType,
|
||||
CTAShape64,
|
||||
ClusterShape64,
|
||||
MainloopScheduleType,
|
||||
EpilogueScheduleType,
|
||||
TileSchedulerType,
|
||||
false>;
|
||||
using Gemm16 = DeviceGemmFp8RowwiseSm100<
|
||||
ElementInput,
|
||||
ElementOutput,
|
||||
AccumElementType,
|
||||
CTAShape16,
|
||||
ClusterShape16,
|
||||
MainloopScheduleType,
|
||||
EpilogueScheduleType,
|
||||
TileSchedulerType,
|
||||
false>;
|
||||
|
||||
// next power of 2 (minimum 16)
|
||||
uint32_t const m = a.size(0);
|
||||
uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m));
|
||||
|
||||
if (bias) {
|
||||
using Gemm = DeviceGemmFp8RowwiseSm100<
|
||||
ElementInput,
|
||||
ElementOutput,
|
||||
AccumElementType,
|
||||
CTAShape,
|
||||
ClusterShape,
|
||||
MainloopScheduleType,
|
||||
EpilogueScheduleType,
|
||||
TileSchedulerType,
|
||||
true>;
|
||||
return launch_sm100_fp8_scaled_mm<Gemm, true>(out, a, b, scales_a, scales_b, bias);
|
||||
if (mp2 <= 16) {
|
||||
// m in [1, 16]
|
||||
return launch_sm100_fp8_scaled_mm<BiasGemm16, true>(out, a, b, scales_a, scales_b, bias);
|
||||
} else if (mp2 <= 64) {
|
||||
// m in (16, 64]
|
||||
return launch_sm100_fp8_scaled_mm<BiasGemm64, true>(out, a, b, scales_a, scales_b, bias);
|
||||
} else if (mp2 <= 256) {
|
||||
// m in (64, 256]
|
||||
return launch_sm100_fp8_scaled_mm<BiasGemm256, true>(out, a, b, scales_a, scales_b, bias);
|
||||
} else {
|
||||
// m in (256, inf)
|
||||
return launch_sm100_fp8_scaled_mm<BiasGemmDefault, true>(out, a, b, scales_a, scales_b, bias);
|
||||
}
|
||||
} else {
|
||||
using Gemm = DeviceGemmFp8RowwiseSm100<
|
||||
ElementInput,
|
||||
ElementOutput,
|
||||
AccumElementType,
|
||||
CTAShape,
|
||||
ClusterShape,
|
||||
MainloopScheduleType,
|
||||
EpilogueScheduleType,
|
||||
TileSchedulerType,
|
||||
false>;
|
||||
return launch_sm100_fp8_scaled_mm<Gemm, false>(out, a, b, scales_a, scales_b, bias);
|
||||
if (mp2 <= 16) {
|
||||
// m in [1, 16]
|
||||
return launch_sm100_fp8_scaled_mm<Gemm16, false>(out, a, b, scales_a, scales_b, bias);
|
||||
} else if (mp2 <= 64) {
|
||||
// m in (16, 64]
|
||||
return launch_sm100_fp8_scaled_mm<Gemm64, false>(out, a, b, scales_a, scales_b, bias);
|
||||
} else if (mp2 <= 256) {
|
||||
// m in (64, 256]
|
||||
return launch_sm100_fp8_scaled_mm<Gemm256, false>(out, a, b, scales_a, scales_b, bias);
|
||||
} else {
|
||||
return launch_sm100_fp8_scaled_mm<GemmDefault, false>(out, a, b, scales_a, scales_b, bias);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user