[2/2] Add python wrapper for CUTLASS FP8 Blockscale MoE Kernel. (#5694)

This commit is contained in:
Elfie Guo
2025-05-16 13:14:07 -07:00
committed by GitHub
parent 839fb31e5f
commit 6fc9357503
12 changed files with 896 additions and 41 deletions

View File

@@ -1,3 +1,5 @@
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cutlass/arch/arch.h>
#include <torch/all.h>
@@ -49,23 +51,16 @@ void launch_sm100_fp8_blockwise_scaled_group_mm(
using ElementC = OutType;
using ElementD = ElementC;
using ElementAccumulator = float;
// Layout definitions
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = LayoutD;
// Alignment constraints
static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
// Architecture definitions
using ArchTag = cutlass::arch::Sm100;
using OperatorClass = cutlass::arch::OpClassTensorOp;
// For fp8 block scale.
// using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig<ScaleGranularityM, ScaleGranularityN,
// ScaleGranularityK, cute::UMMA::Major::K, cute::UMMA::Major::K>; using LayoutSFA =
// decltype(ScaleConfig::deduce_layoutSFA()); using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OperatorClass,
@@ -124,9 +119,8 @@ void launch_sm100_fp8_blockwise_scaled_group_mm(
cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = 0;
hw_info.sm_count = 1;
// Currently, we are only able to do broadcast on either all or none a_scales
// and on either all or none b_scales
// sm_count is the number of SMs on the current device, since we only support SM100 blackwell, so we set it to 148
hw_info.sm_count = 148;
typename GemmKernel::EpilogueArguments epilogue_args{
{},
nullptr,
@@ -134,9 +128,7 @@ void launch_sm100_fp8_blockwise_scaled_group_mm(
static_cast<ElementD**>(out_ptrs.data_ptr()),
static_cast<StrideC*>(stride_c.data_ptr())};
// Initialize problem_sizes_as_shapes correctly
UnderlyingProblemShape* problem_sizes_as_shapes = static_cast<UnderlyingProblemShape*>(problem_sizes.data_ptr());
// Use prob_shape in the GEMM arguments
typename GemmKernel::Arguments args{
cutlass::gemm::GemmUniversalMode::kGrouped,
{num_experts, problem_sizes_as_shapes, nullptr},
@@ -144,21 +136,27 @@ void launch_sm100_fp8_blockwise_scaled_group_mm(
epilogue_args,
hw_info};
at::cuda::CUDAGuard device_guard{(char)a_ptrs.get_device()};
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(a_ptrs.get_device());
auto can_implement_status = gemm_op.can_implement(args);
TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess, "Failed to implement GEMM");
// Run the GEMM
auto status = gemm_op.initialize(args, workspace.data_ptr());
auto status = gemm_op.initialize(args, workspace.data_ptr(), stream);
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM");
status = gemm_op.run();
status = gemm_op.run(stream);
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
}
template <typename OutType>
void sm100_fp8_blockwise_group_mm_dispatch_shape(
torch::Tensor& output,
torch::Tensor& a_ptrs,
torch::Tensor& b_ptrs,
torch::Tensor& out_ptrs,
torch::Tensor& a_scales_ptrs,
torch::Tensor& b_scales_ptrs,
const torch::Tensor& a,
const torch::Tensor& b,
const torch::Tensor& scales_a,
@@ -169,11 +167,23 @@ void sm100_fp8_blockwise_group_mm_dispatch_shape(
const torch::Tensor& layout_sfa,
const torch::Tensor& layout_sfb,
const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets) {
const torch::Tensor& expert_offsets,
const torch::Tensor& workspace) {
// Check the first matrix size to decide on the configuration
// Assuming all matrices in the group have similar size characteristics
// bool use_small_config = a[0].size(0) <= 128;
struct MMALargeConfig {
struct MmaConfig1 {
using ElementA = cutlass::float_e4m3_t;
using MmaTileShape = Shape<_128, _32, _128>;
using ClusterShape = Shape<_1, _1, _1>; // Layout type for SFB matrix operand
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedBlockwise1SmSm100;
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm;
using ScaleConfig =
cutlass::detail::Sm100BlockwiseScaleConfig<128, 1, 128, cute::UMMA::Major::K, cute::UMMA::Major::K>;
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
};
struct MmaConfig2 {
using ElementA = cutlass::float_e4m3_t;
using MmaTileShape = Shape<_128, _128, _128>;
using ClusterShape = Shape<_1, _1, _1>; // Layout type for SFB matrix operand
@@ -184,35 +194,28 @@ void sm100_fp8_blockwise_group_mm_dispatch_shape(
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
};
struct MMASmallConfig {
struct MmaConfig3 {
using ElementA = cutlass::float_e4m3_t;
using MmaTileShape = Shape<_128, _16, _128>;
using MmaTileShape = Shape<_64, _128, _128>;
using ClusterShape = Shape<_1, _1, _1>; // Layout type for SFB matrix operand
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedBlockwise1SmSm100;
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm;
using ScaleConfig =
cutlass::detail::Sm100BlockwiseScaleConfig<128, 1, 128, cute::UMMA::Major::K, cute::UMMA::Major::K>;
cutlass::detail::Sm100BlockwiseScaleConfig<1, 128, 128, cute::UMMA::Major::K, cute::UMMA::Major::K>;
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
};
int num_experts = (int)expert_offsets.size(0);
torch::TensorOptions options_int = torch::TensorOptions().dtype(torch::kInt64).device(a.device());
torch::Tensor a_ptrs = torch::empty(num_experts, options_int);
torch::Tensor b_ptrs = torch::empty(num_experts, options_int);
torch::Tensor out_ptrs = torch::empty(num_experts, options_int);
torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int);
torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int);
torch::Tensor problem_sizes_transpose = torch::empty(num_experts * 3, options_int);
torch::Tensor workspace = torch::empty(100, options_int);
torch::Tensor output_t = output.t();
torch::Tensor a_t = a.t();
torch::Tensor b_t = b.transpose(1, 2);
torch::Tensor scales_a_t = scales_a.t();
torch::Tensor scales_b_t = scales_b.transpose(1, 2);
if (a.size(0) <= 512) {
run_get_group_gemm_starts<MMASmallConfig::LayoutSFA, MMASmallConfig::LayoutSFB, MMASmallConfig::ScaleConfig>(
if (a.size(0) <= 512 && a.size(1) >= 2048) {
run_get_group_gemm_starts<MmaConfig1::LayoutSFA, MmaConfig1::LayoutSFB, MmaConfig1::ScaleConfig>(
expert_offsets,
a_ptrs,
b_ptrs,
@@ -229,7 +232,7 @@ void sm100_fp8_blockwise_group_mm_dispatch_shape(
problem_sizes,
problem_sizes_transpose,
true);
launch_sm100_fp8_blockwise_scaled_group_mm<OutType, MMASmallConfig, cutlass::layout::ColumnMajor>(
launch_sm100_fp8_blockwise_scaled_group_mm<OutType, MmaConfig1, cutlass::layout::ColumnMajor>(
out_ptrs,
a_ptrs,
b_ptrs,
@@ -244,8 +247,8 @@ void sm100_fp8_blockwise_group_mm_dispatch_shape(
expert_offsets,
workspace);
output = output_t.t();
} else {
run_get_group_gemm_starts<MMALargeConfig::LayoutSFA, MMALargeConfig::LayoutSFB, MMALargeConfig::ScaleConfig>(
} else if (a.size(0) > 512 && a.size(1) >= 2048) {
run_get_group_gemm_starts<MmaConfig2::LayoutSFA, MmaConfig2::LayoutSFB, MmaConfig2::ScaleConfig>(
expert_offsets,
a_ptrs,
b_ptrs,
@@ -261,7 +264,38 @@ void sm100_fp8_blockwise_group_mm_dispatch_shape(
layout_sfb,
problem_sizes,
problem_sizes_transpose);
launch_sm100_fp8_blockwise_scaled_group_mm<OutType, MMALargeConfig, cutlass::layout::RowMajor>(
launch_sm100_fp8_blockwise_scaled_group_mm<OutType, MmaConfig2, cutlass::layout::RowMajor>(
out_ptrs,
a_ptrs,
b_ptrs,
a_scales_ptrs,
b_scales_ptrs,
stride_a,
stride_b,
stride_c,
layout_sfa,
layout_sfb,
problem_sizes,
expert_offsets,
workspace);
} else {
run_get_group_gemm_starts<MmaConfig3::LayoutSFA, MmaConfig3::LayoutSFB, MmaConfig3::ScaleConfig>(
expert_offsets,
a_ptrs,
b_ptrs,
out_ptrs,
a_scales_ptrs,
b_scales_ptrs,
a,
b,
output,
scales_a,
scales_b,
layout_sfa,
layout_sfb,
problem_sizes,
problem_sizes_transpose);
launch_sm100_fp8_blockwise_scaled_group_mm<OutType, MmaConfig3, cutlass::layout::RowMajor>(
out_ptrs,
a_ptrs,
b_ptrs,
@@ -312,6 +346,11 @@ void sm100_fp8_blockwise_group_mm_dispatch_shape(
*/
void fp8_blockwise_scaled_grouped_mm(
torch::Tensor& output,
torch::Tensor& a_ptrs,
torch::Tensor& b_ptrs,
torch::Tensor& out_ptrs,
torch::Tensor& a_scales_ptrs,
torch::Tensor& b_scales_ptrs,
const torch::Tensor& a,
const torch::Tensor& b,
const torch::Tensor& scales_a,
@@ -322,7 +361,8 @@ void fp8_blockwise_scaled_grouped_mm(
const torch::Tensor& layout_sfa,
const torch::Tensor& layout_sfb,
const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets) {
const torch::Tensor& expert_offsets,
const torch::Tensor& workspace) {
TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor");
TORCH_CHECK(problem_sizes.size(1) == 3, "problem_sizes must have shape (num_experts, 3)");
TORCH_CHECK(
@@ -342,6 +382,29 @@ void fp8_blockwise_scaled_grouped_mm(
TORCH_CHECK(layout_sfb.scalar_type() == torch::kInt32, "layout_sfb must be int32");
TORCH_CHECK(expert_offsets.scalar_type() == torch::kInt32, "expert_offsets must be int32");
TORCH_CHECK(output.dim() == 2, "output must be 2D tensor");
TORCH_CHECK(a.dim() == 2, "a must be 2D tensor");
TORCH_CHECK(b.dim() == 3, "b must be 3D tensor");
TORCH_CHECK(scales_a.dim() == 2, "scales_a must be 2D tensor");
TORCH_CHECK(scales_b.dim() == 3, "scales_b must be 3D tensor");
TORCH_CHECK(stride_a.dim() == 1, "stride_a must be 1D tensor");
TORCH_CHECK(stride_b.dim() == 1, "stride_b must be 1D tensor");
TORCH_CHECK(stride_c.dim() == 1, "stride_c must be 1D tensor");
TORCH_CHECK(layout_sfa.dim() == 2, "layout_sfa must be 1D tensor");
TORCH_CHECK(layout_sfb.dim() == 2, "layout_sfb must be 1D tensor");
TORCH_CHECK(a_ptrs.dim() == 1, "a_ptrs must be 1D tensor");
TORCH_CHECK(b_ptrs.dim() == 1, "b_ptrs must be 1D tensor");
TORCH_CHECK(out_ptrs.dim() == 1, "out_ptrs must be 1D tensor");
TORCH_CHECK(a_scales_ptrs.dim() == 1, "a_scales_ptrs must be 1D tensor");
TORCH_CHECK(b_scales_ptrs.dim() == 1, "b_scales_ptrs must be 1D tensor");
TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor");
TORCH_CHECK(problem_sizes.size(1) == 3, "problem_sizes must have shape (num_experts, 3)");
TORCH_CHECK(
problem_sizes.size(0) == expert_offsets.size(0), "Number of experts in problem_sizes must match expert_offsets");
TORCH_CHECK(problem_sizes.dtype() == torch::kInt32, "problem_sizes must be int32");
TORCH_CHECK(expert_offsets.dim() == 1, "expert_offsets must be 1D tensor");
TORCH_CHECK(workspace.dim() == 1, "workspace must be 1D tensor");
bool can_implement = false;
auto sm_version = getSMVersion();
@@ -351,6 +414,11 @@ void fp8_blockwise_scaled_grouped_mm(
if (output.scalar_type() == torch::kBFloat16) {
sm100_fp8_blockwise_group_mm_dispatch_shape<cutlass::bfloat16_t>(
output,
a_ptrs,
b_ptrs,
out_ptrs,
a_scales_ptrs,
b_scales_ptrs,
a,
b,
scales_a,
@@ -361,10 +429,16 @@ void fp8_blockwise_scaled_grouped_mm(
layout_sfa,
layout_sfb,
problem_sizes,
expert_offsets);
expert_offsets,
workspace);
} else {
sm100_fp8_blockwise_group_mm_dispatch_shape<cutlass::half_t>(
output,
a_ptrs,
b_ptrs,
out_ptrs,
a_scales_ptrs,
b_scales_ptrs,
a,
b,
scales_a,
@@ -375,7 +449,8 @@ void fp8_blockwise_scaled_grouped_mm(
layout_sfa,
layout_sfb,
problem_sizes,
expert_offsets);
expert_offsets,
workspace);
}
can_implement = true;
}