Optimize Hopper CUTLASS FP8 Blockwise Grouped GEMM Kernel in Small K Scenario (#7782)
This commit is contained in:
@@ -61,7 +61,12 @@ void launch_sm90_fp8_blockwise_scaled_group_mm(
|
||||
|
||||
using ArchTag = cutlass::arch::Sm90;
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp;
|
||||
using FusionOperation = cutlass::epilogue::fusion::LinearCombination<ElementD, ElementAccumulator>;
|
||||
static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
|
||||
using CustomEVTIdentity = // acc
|
||||
cutlass::epilogue::fusion::Sm90EVT<
|
||||
cutlass::epilogue::fusion::
|
||||
Sm90Compute<cutlass::epilogue::thread::Identity, ElementD, ElementAccumulator, RoundStyle>,
|
||||
cutlass::epilogue::fusion::Sm90AccFetch>;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
@@ -78,7 +83,7 @@ void launch_sm90_fp8_blockwise_scaled_group_mm(
|
||||
LayoutC*,
|
||||
AlignmentC,
|
||||
typename ScheduleConfig::EpilogueSchedule,
|
||||
FusionOperation>::CollectiveOp;
|
||||
CustomEVTIdentity>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
@@ -452,7 +457,7 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape(
|
||||
const torch::Tensor& problem_sizes,
|
||||
const torch::Tensor& expert_offsets,
|
||||
const torch::Tensor& workspace) {
|
||||
struct MmaConfig {
|
||||
struct MmaConfig0 {
|
||||
using ElementA = cutlass::float_e4m3_t;
|
||||
using MmaTileShape = Shape<_64, _128, _128>;
|
||||
using ClusterShape = Shape<_2, _1, _1>;
|
||||
@@ -464,40 +469,86 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape(
|
||||
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
|
||||
};
|
||||
|
||||
struct MmaConfig1 {
|
||||
using ElementA = cutlass::float_e4m3_t;
|
||||
using MmaTileShape = Shape<_128, _128, _128>;
|
||||
using ClusterShape = Shape<_1, _2, _1>;
|
||||
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum;
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative;
|
||||
using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128>;
|
||||
|
||||
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
|
||||
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
|
||||
};
|
||||
|
||||
int num_experts = (int)expert_offsets.size(0);
|
||||
torch::TensorOptions options_int = torch::TensorOptions().dtype(torch::kInt64).device(a.device());
|
||||
torch::Tensor problem_sizes_transpose = torch::empty(num_experts * 3, options_int);
|
||||
|
||||
run_get_group_gemm_starts<MmaConfig::LayoutSFA, MmaConfig::LayoutSFB, MmaConfig::ScaleConfig>(
|
||||
expert_offsets,
|
||||
a_ptrs,
|
||||
b_ptrs,
|
||||
out_ptrs,
|
||||
a_scales_ptrs,
|
||||
b_scales_ptrs,
|
||||
a,
|
||||
b,
|
||||
output,
|
||||
scales_a,
|
||||
scales_b,
|
||||
layout_sfa,
|
||||
layout_sfb,
|
||||
problem_sizes,
|
||||
problem_sizes_transpose);
|
||||
launch_sm90_fp8_blockwise_scaled_group_mm<OutType, MmaConfig, cutlass::layout::RowMajor>(
|
||||
out_ptrs,
|
||||
a_ptrs,
|
||||
b_ptrs,
|
||||
a_scales_ptrs,
|
||||
b_scales_ptrs,
|
||||
stride_a,
|
||||
stride_b,
|
||||
stride_c,
|
||||
layout_sfa,
|
||||
layout_sfb,
|
||||
problem_sizes,
|
||||
expert_offsets,
|
||||
workspace);
|
||||
if (a.size(1) > 128) {
|
||||
run_get_group_gemm_starts<MmaConfig0::LayoutSFA, MmaConfig0::LayoutSFB, MmaConfig0::ScaleConfig>(
|
||||
expert_offsets,
|
||||
a_ptrs,
|
||||
b_ptrs,
|
||||
out_ptrs,
|
||||
a_scales_ptrs,
|
||||
b_scales_ptrs,
|
||||
a,
|
||||
b,
|
||||
output,
|
||||
scales_a,
|
||||
scales_b,
|
||||
layout_sfa,
|
||||
layout_sfb,
|
||||
problem_sizes,
|
||||
problem_sizes_transpose);
|
||||
launch_sm90_fp8_blockwise_scaled_group_mm<OutType, MmaConfig0, cutlass::layout::RowMajor>(
|
||||
out_ptrs,
|
||||
a_ptrs,
|
||||
b_ptrs,
|
||||
a_scales_ptrs,
|
||||
b_scales_ptrs,
|
||||
stride_a,
|
||||
stride_b,
|
||||
stride_c,
|
||||
layout_sfa,
|
||||
layout_sfb,
|
||||
problem_sizes,
|
||||
expert_offsets,
|
||||
workspace);
|
||||
} else {
|
||||
// Small K
|
||||
run_get_group_gemm_starts<MmaConfig1::LayoutSFA, MmaConfig1::LayoutSFB, MmaConfig1::ScaleConfig>(
|
||||
expert_offsets,
|
||||
a_ptrs,
|
||||
b_ptrs,
|
||||
out_ptrs,
|
||||
a_scales_ptrs,
|
||||
b_scales_ptrs,
|
||||
a,
|
||||
b,
|
||||
output,
|
||||
scales_a,
|
||||
scales_b,
|
||||
layout_sfa,
|
||||
layout_sfb,
|
||||
problem_sizes,
|
||||
problem_sizes_transpose);
|
||||
launch_sm90_fp8_blockwise_scaled_group_mm<OutType, MmaConfig1, cutlass::layout::RowMajor>(
|
||||
out_ptrs,
|
||||
a_ptrs,
|
||||
b_ptrs,
|
||||
a_scales_ptrs,
|
||||
b_scales_ptrs,
|
||||
stride_a,
|
||||
stride_b,
|
||||
stride_c,
|
||||
layout_sfa,
|
||||
layout_sfb,
|
||||
problem_sizes,
|
||||
expert_offsets,
|
||||
workspace);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -641,7 +692,7 @@ void fp8_blockwise_scaled_grouped_mm(
|
||||
#endif
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
|
||||
if (sm_version == 90 && a.size(1) > 256) {
|
||||
if (sm_version == 90) {
|
||||
if (output.scalar_type() == torch::kBFloat16) {
|
||||
sm90_fp8_blockwise_group_mm_dispatch_shape<cutlass::bfloat16_t>(
|
||||
output,
|
||||
@@ -687,8 +738,5 @@ void fp8_blockwise_scaled_grouped_mm(
|
||||
}
|
||||
#endif
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
can_implement,
|
||||
"No implemented fp8_blockwise_scaled_mm for current compute capability or K size: ",
|
||||
sm_version,
|
||||
a.size(1));
|
||||
can_implement, "No implemented fp8_blockwise_scaled_mm for current compute capability: ", sm_version);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user