Optimize Hopper CUTLASS FP8 Blockwise Grouped GEMM Kernel in Small K Scenario (#7782)
This commit is contained in:
@@ -61,7 +61,12 @@ void launch_sm90_fp8_blockwise_scaled_group_mm(
|
|||||||
|
|
||||||
using ArchTag = cutlass::arch::Sm90;
|
using ArchTag = cutlass::arch::Sm90;
|
||||||
using OperatorClass = cutlass::arch::OpClassTensorOp;
|
using OperatorClass = cutlass::arch::OpClassTensorOp;
|
||||||
using FusionOperation = cutlass::epilogue::fusion::LinearCombination<ElementD, ElementAccumulator>;
|
static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
|
||||||
|
using CustomEVTIdentity = // acc
|
||||||
|
cutlass::epilogue::fusion::Sm90EVT<
|
||||||
|
cutlass::epilogue::fusion::
|
||||||
|
Sm90Compute<cutlass::epilogue::thread::Identity, ElementD, ElementAccumulator, RoundStyle>,
|
||||||
|
cutlass::epilogue::fusion::Sm90AccFetch>;
|
||||||
|
|
||||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||||
ArchTag,
|
ArchTag,
|
||||||
@@ -78,7 +83,7 @@ void launch_sm90_fp8_blockwise_scaled_group_mm(
|
|||||||
LayoutC*,
|
LayoutC*,
|
||||||
AlignmentC,
|
AlignmentC,
|
||||||
typename ScheduleConfig::EpilogueSchedule,
|
typename ScheduleConfig::EpilogueSchedule,
|
||||||
FusionOperation>::CollectiveOp;
|
CustomEVTIdentity>::CollectiveOp;
|
||||||
|
|
||||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||||
ArchTag,
|
ArchTag,
|
||||||
@@ -452,7 +457,7 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape(
|
|||||||
const torch::Tensor& problem_sizes,
|
const torch::Tensor& problem_sizes,
|
||||||
const torch::Tensor& expert_offsets,
|
const torch::Tensor& expert_offsets,
|
||||||
const torch::Tensor& workspace) {
|
const torch::Tensor& workspace) {
|
||||||
struct MmaConfig {
|
struct MmaConfig0 {
|
||||||
using ElementA = cutlass::float_e4m3_t;
|
using ElementA = cutlass::float_e4m3_t;
|
||||||
using MmaTileShape = Shape<_64, _128, _128>;
|
using MmaTileShape = Shape<_64, _128, _128>;
|
||||||
using ClusterShape = Shape<_2, _1, _1>;
|
using ClusterShape = Shape<_2, _1, _1>;
|
||||||
@@ -464,40 +469,86 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape(
|
|||||||
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
|
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct MmaConfig1 {
|
||||||
|
using ElementA = cutlass::float_e4m3_t;
|
||||||
|
using MmaTileShape = Shape<_128, _128, _128>;
|
||||||
|
using ClusterShape = Shape<_1, _2, _1>;
|
||||||
|
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum;
|
||||||
|
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative;
|
||||||
|
using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128>;
|
||||||
|
|
||||||
|
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
|
||||||
|
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
|
||||||
|
};
|
||||||
|
|
||||||
int num_experts = (int)expert_offsets.size(0);
|
int num_experts = (int)expert_offsets.size(0);
|
||||||
torch::TensorOptions options_int = torch::TensorOptions().dtype(torch::kInt64).device(a.device());
|
torch::TensorOptions options_int = torch::TensorOptions().dtype(torch::kInt64).device(a.device());
|
||||||
torch::Tensor problem_sizes_transpose = torch::empty(num_experts * 3, options_int);
|
torch::Tensor problem_sizes_transpose = torch::empty(num_experts * 3, options_int);
|
||||||
|
|
||||||
run_get_group_gemm_starts<MmaConfig::LayoutSFA, MmaConfig::LayoutSFB, MmaConfig::ScaleConfig>(
|
if (a.size(1) > 128) {
|
||||||
expert_offsets,
|
run_get_group_gemm_starts<MmaConfig0::LayoutSFA, MmaConfig0::LayoutSFB, MmaConfig0::ScaleConfig>(
|
||||||
a_ptrs,
|
expert_offsets,
|
||||||
b_ptrs,
|
a_ptrs,
|
||||||
out_ptrs,
|
b_ptrs,
|
||||||
a_scales_ptrs,
|
out_ptrs,
|
||||||
b_scales_ptrs,
|
a_scales_ptrs,
|
||||||
a,
|
b_scales_ptrs,
|
||||||
b,
|
a,
|
||||||
output,
|
b,
|
||||||
scales_a,
|
output,
|
||||||
scales_b,
|
scales_a,
|
||||||
layout_sfa,
|
scales_b,
|
||||||
layout_sfb,
|
layout_sfa,
|
||||||
problem_sizes,
|
layout_sfb,
|
||||||
problem_sizes_transpose);
|
problem_sizes,
|
||||||
launch_sm90_fp8_blockwise_scaled_group_mm<OutType, MmaConfig, cutlass::layout::RowMajor>(
|
problem_sizes_transpose);
|
||||||
out_ptrs,
|
launch_sm90_fp8_blockwise_scaled_group_mm<OutType, MmaConfig0, cutlass::layout::RowMajor>(
|
||||||
a_ptrs,
|
out_ptrs,
|
||||||
b_ptrs,
|
a_ptrs,
|
||||||
a_scales_ptrs,
|
b_ptrs,
|
||||||
b_scales_ptrs,
|
a_scales_ptrs,
|
||||||
stride_a,
|
b_scales_ptrs,
|
||||||
stride_b,
|
stride_a,
|
||||||
stride_c,
|
stride_b,
|
||||||
layout_sfa,
|
stride_c,
|
||||||
layout_sfb,
|
layout_sfa,
|
||||||
problem_sizes,
|
layout_sfb,
|
||||||
expert_offsets,
|
problem_sizes,
|
||||||
workspace);
|
expert_offsets,
|
||||||
|
workspace);
|
||||||
|
} else {
|
||||||
|
// Small K
|
||||||
|
run_get_group_gemm_starts<MmaConfig1::LayoutSFA, MmaConfig1::LayoutSFB, MmaConfig1::ScaleConfig>(
|
||||||
|
expert_offsets,
|
||||||
|
a_ptrs,
|
||||||
|
b_ptrs,
|
||||||
|
out_ptrs,
|
||||||
|
a_scales_ptrs,
|
||||||
|
b_scales_ptrs,
|
||||||
|
a,
|
||||||
|
b,
|
||||||
|
output,
|
||||||
|
scales_a,
|
||||||
|
scales_b,
|
||||||
|
layout_sfa,
|
||||||
|
layout_sfb,
|
||||||
|
problem_sizes,
|
||||||
|
problem_sizes_transpose);
|
||||||
|
launch_sm90_fp8_blockwise_scaled_group_mm<OutType, MmaConfig1, cutlass::layout::RowMajor>(
|
||||||
|
out_ptrs,
|
||||||
|
a_ptrs,
|
||||||
|
b_ptrs,
|
||||||
|
a_scales_ptrs,
|
||||||
|
b_scales_ptrs,
|
||||||
|
stride_a,
|
||||||
|
stride_b,
|
||||||
|
stride_c,
|
||||||
|
layout_sfa,
|
||||||
|
layout_sfb,
|
||||||
|
problem_sizes,
|
||||||
|
expert_offsets,
|
||||||
|
workspace);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -641,7 +692,7 @@ void fp8_blockwise_scaled_grouped_mm(
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
|
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
|
||||||
if (sm_version == 90 && a.size(1) > 256) {
|
if (sm_version == 90) {
|
||||||
if (output.scalar_type() == torch::kBFloat16) {
|
if (output.scalar_type() == torch::kBFloat16) {
|
||||||
sm90_fp8_blockwise_group_mm_dispatch_shape<cutlass::bfloat16_t>(
|
sm90_fp8_blockwise_group_mm_dispatch_shape<cutlass::bfloat16_t>(
|
||||||
output,
|
output,
|
||||||
@@ -687,8 +738,5 @@ void fp8_blockwise_scaled_grouped_mm(
|
|||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||||
can_implement,
|
can_implement, "No implemented fp8_blockwise_scaled_mm for current compute capability: ", sm_version);
|
||||||
"No implemented fp8_blockwise_scaled_mm for current compute capability or K size: ",
|
|
||||||
sm_version,
|
|
||||||
a.size(1));
|
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user