Add CUTLASS FP8 Blockscale MoE kernel for Hopper architecture (#7278)

Co-authored-by: HydraQYH <QYH820@Outlook.com>
Co-authored-by: TianQiLin666666 <1834987979@qq.com>
Author: ayrnb
Date: 2025-07-03 14:27:03 +08:00
Committed by: GitHub
Parent: 2ff572e28c
Commit: 2c4feaf308
3 changed files with 578 additions and 9 deletions

sgl-kernel/csrc/moe/fp8_blockwise_moe_kernel.cu (245 changed lines; mode: Executable file → Normal file)

@@ -30,6 +30,126 @@
using namespace cute;
using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int, int, int>>;
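// Host-side launcher for the SM90 (Hopper) grouped GEMM: each group is one
// expert's (M_i, N, K) problem, with FP8 e4m3 inputs and blockwise scales.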
template <typename OutType, typename ScheduleConfig, typename LayoutD>
void launch_sm90_fp8_blockwise_scaled_group_mm(
torch::Tensor& out_ptrs,
const torch::Tensor& a_ptrs,
const torch::Tensor& b_ptrs,
const torch::Tensor& a_scales_ptrs,
const torch::Tensor& b_scales_ptrs,
const torch::Tensor& stride_a,
const torch::Tensor& stride_b,
const torch::Tensor& stride_c,
const torch::Tensor& layout_sfa,
const torch::Tensor& layout_sfb,
const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets,
const torch::Tensor& workspace) {
using ElementA = cutlass::float_e4m3_t;
using ElementB = cutlass::float_e4m3_t;
using ElementC = void;
using ElementD = OutType;
using ElementAccumulator = float;
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = LayoutD;
static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementD>::value;
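// 128-bit alignment expressed in elements: 16 for the FP8 operands,
// 8 for the 16-bit output type.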
using ArchTag = cutlass::arch::Sm90;
using OperatorClass = cutlass::arch::OpClassTensorOp;
using FusionOperation = cutlass::epilogue::fusion::LinearCombination<ElementD, ElementAccumulator>;
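// Build the epilogue first so its shared-memory footprint can be carved out
// of the mainloop stage count via StageCountAutoCarveout below.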
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OperatorClass,
typename ScheduleConfig::MmaTileShape,
typename ScheduleConfig::ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator,
ElementAccumulator,
ElementC, // Use void to avoid loading matrix C
LayoutC*,
AlignmentC,
ElementD,
LayoutC*,
AlignmentC,
typename ScheduleConfig::EpilogueSchedule,
FusionOperation>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OperatorClass,
ElementA,
cute::tuple<LayoutA*, typename ScheduleConfig::LayoutSFA*>,
AlignmentA,
ElementB,
cute::tuple<LayoutB*, typename ScheduleConfig::LayoutSFB*>,
AlignmentB,
ElementAccumulator,
typename ScheduleConfig::MmaTileShape,
typename ScheduleConfig::ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
typename ScheduleConfig::KernelSchedule>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop, CollectiveEpilogue, void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape;
using StrideA = typename Gemm::GemmKernel::InternalStrideA;
using StrideB = typename Gemm::GemmKernel::InternalStrideB;
using StrideC = typename Gemm::GemmKernel::InternalStrideC;
using StrideD = typename Gemm::GemmKernel::InternalStrideD;
int num_experts = (int)expert_offsets.size(0);
Gemm gemm_op;
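// Grouped GEMM consumes device-resident arrays of per-expert pointers and
// strides rather than a single base pointer, so every argument below is an array.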
typename GemmKernel::MainloopArguments mainloop_args{
static_cast<const ElementA**>(a_ptrs.data_ptr()),
static_cast<StrideA*>(stride_a.data_ptr()),
static_cast<const ElementB**>(b_ptrs.data_ptr()),
static_cast<StrideB*>(stride_b.data_ptr()),
static_cast<const ElementAccumulator**>(a_scales_ptrs.data_ptr()),
reinterpret_cast<typename ScheduleConfig::LayoutSFA*>(layout_sfa.data_ptr()),
static_cast<const ElementAccumulator**>(b_scales_ptrs.data_ptr()),
reinterpret_cast<typename ScheduleConfig::LayoutSFB*>(layout_sfb.data_ptr())};
cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = c10::cuda::current_device();
hw_info.sm_count = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
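// ElementC is void, so no source matrix is loaded in the epilogue: the C
// pointer is null and only per-expert output pointers and strides are passed.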
typename GemmKernel::EpilogueArguments epilogue_args{
{},
nullptr,
static_cast<StrideC*>(stride_c.data_ptr()),
static_cast<ElementD**>(out_ptrs.data_ptr()),
static_cast<StrideC*>(stride_c.data_ptr())};
UnderlyingProblemShape* problem_sizes_as_shapes = static_cast<UnderlyingProblemShape*>(problem_sizes.data_ptr());
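// kGrouped mode: one (M, N, K) problem shape per expert, read from device
// memory; the optional host-side problem-shape pointer stays null.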
typename GemmKernel::Arguments args{
cutlass::gemm::GemmUniversalMode::kGrouped,
{num_experts, problem_sizes_as_shapes, nullptr},
mainloop_args,
epilogue_args,
hw_info};
at::cuda::CUDAGuard device_guard{(char)a_ptrs.get_device()};
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(a_ptrs.get_device());
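// Standard CUTLASS flow: check that the arguments are supported, initialize
// with the caller-provided workspace, then run on the current CUDA stream.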
auto can_implement_status = gemm_op.can_implement(args);
TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess, "GEMM arguments are not supported by this kernel");
auto status = gemm_op.initialize(args, workspace.data_ptr(), stream);
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM");
status = gemm_op.run(stream);
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
}
template <typename OutType, typename ScheduleConfig, typename LayoutD>
void launch_sm100_fp8_blockwise_scaled_group_mm(
torch::Tensor& out_ptrs,
@@ -312,6 +432,74 @@ void sm100_fp8_blockwise_group_mm_dispatch_shape(
}
}
template <typename OutType>
void sm90_fp8_blockwise_group_mm_dispatch_shape(
torch::Tensor& output,
torch::Tensor& a_ptrs,
torch::Tensor& b_ptrs,
torch::Tensor& out_ptrs,
torch::Tensor& a_scales_ptrs,
torch::Tensor& b_scales_ptrs,
const torch::Tensor& a,
const torch::Tensor& b,
const torch::Tensor& scales_a,
const torch::Tensor& scales_b,
const torch::Tensor& stride_a,
const torch::Tensor& stride_b,
const torch::Tensor& stride_c,
const torch::Tensor& layout_sfa,
const torch::Tensor& layout_sfb,
const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets,
const torch::Tensor& workspace) {
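// A single schedule is used on SM90: 64x128x128 tiles, 2x1x1 clusters, and a
// ping-pong warp-specialized TMA kernel with FP8 block-scaled accumulation.
// Scale granularity is 1x128 for A (per row, per 128-wide K block) and
// 128x128 for B.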
struct MmaConfig {
using ElementA = cutlass::float_e4m3_t;
using MmaTileShape = Shape<_64, _128, _128>;
using ClusterShape = Shape<_2, _1, _1>;
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum;
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128>;
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
};
int num_experts = (int)expert_offsets.size(0);
torch::TensorOptions options_int = torch::TensorOptions().dtype(torch::kInt64).device(a.device());
torch::Tensor problem_sizes_transpose = torch::empty(num_experts * 3, options_int);
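// Fills the per-expert pointer arrays (A, B, output, scales) from
// expert_offsets and writes the problem sizes in the layout the kernel expects.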
run_get_group_gemm_starts<MmaConfig::LayoutSFA, MmaConfig::LayoutSFB, MmaConfig::ScaleConfig>(
expert_offsets,
a_ptrs,
b_ptrs,
out_ptrs,
a_scales_ptrs,
b_scales_ptrs,
a,
b,
output,
scales_a,
scales_b,
layout_sfa,
layout_sfb,
problem_sizes,
problem_sizes_transpose);
launch_sm90_fp8_blockwise_scaled_group_mm<OutType, MmaConfig, cutlass::layout::RowMajor>(
out_ptrs,
a_ptrs,
b_ptrs,
a_scales_ptrs,
b_scales_ptrs,
stride_a,
stride_b,
stride_c,
layout_sfa,
layout_sfb,
problem_sizes,
expert_offsets,
workspace);
}
/**
* @brief Performs blockwise grouped matrix multiplication on FP8 quantized inputs
* with per-block scaling.
@@ -397,11 +585,6 @@ void fp8_blockwise_scaled_grouped_mm(
TORCH_CHECK(out_ptrs.dim() == 1, "out_ptrs must be 1D tensor");
TORCH_CHECK(a_scales_ptrs.dim() == 1, "a_scales_ptrs must be 1D tensor");
TORCH_CHECK(b_scales_ptrs.dim() == 1, "b_scales_ptrs must be 1D tensor");
TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor");
TORCH_CHECK(problem_sizes.size(1) == 3, "problem_sizes must have shape (num_experts, 3)");
TORCH_CHECK(
problem_sizes.size(0) == expert_offsets.size(0), "Number of experts in problem_sizes must match expert_offsets");
TORCH_CHECK(problem_sizes.dtype() == torch::kInt32, "problem_sizes must be int32");
TORCH_CHECK(expert_offsets.dim() == 1, "expert_offsets must be 1D tensor");
TORCH_CHECK(workspace.dim() == 1, "workspace must be 1D tensor");
@@ -455,7 +638,57 @@ void fp8_blockwise_scaled_grouped_mm(
can_implement = true;
}
#endif
#endif
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
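// The SM90 path is only taken for sufficiently large K (a.size(1) > 256);
// for smaller K, if no other path matched, the not-implemented check below
// fires and reports both the compute capability and the K size.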
if (sm_version == 90 && a.size(1) > 256) {
if (output.scalar_type() == torch::kBFloat16) {
sm90_fp8_blockwise_group_mm_dispatch_shape<cutlass::bfloat16_t>(
output,
a_ptrs,
b_ptrs,
out_ptrs,
a_scales_ptrs,
b_scales_ptrs,
a,
b,
scales_a,
scales_b,
stride_a,
stride_b,
stride_c,
layout_sfa,
layout_sfb,
problem_sizes,
expert_offsets,
workspace);
} else {
sm90_fp8_blockwise_group_mm_dispatch_shape<cutlass::half_t>(
output,
a_ptrs,
b_ptrs,
out_ptrs,
a_scales_ptrs,
b_scales_ptrs,
a,
b,
scales_a,
scales_b,
stride_a,
stride_b,
stride_c,
layout_sfa,
layout_sfb,
problem_sizes,
expert_offsets,
workspace);
}
can_implement = true;
}
#endif
TORCH_CHECK_NOT_IMPLEMENTED(
can_implement,
"No fp8_blockwise_scaled_mm implementation for current compute capability or K size: ",
sm_version,
a.size(1));
}