diff --git a/sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu b/sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu
index 4f9e3b959..b8b23c427 100644
--- a/sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu
+++ b/sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu
@@ -195,6 +195,176 @@ void sm100_fp8_blockwise_dispatch_shape(
   }
 }
 
+template <
+    typename OutType,
+    typename MmaTileShape,
+    typename PerSmTileShape,
+    typename EpilogueTileShape,
+    typename ScalesPerTile,
+    int TileSizeM_ = 128,
+    class ClusterShape = Shape<_1, _1, _1>>
+void launch_sm120_fp8_blockwise_scaled_mm(
+    torch::Tensor& out,
+    const torch::Tensor& a,
+    const torch::Tensor& b,
+    const torch::Tensor& scales_a,
+    const torch::Tensor& scales_b) {
+  using ElementBlockScale = float;
+
+  // A matrix configuration
+  using ElementA = cutlass::float_e4m3_t;        // Element type for A matrix operand
+  using LayoutATag = cutlass::layout::RowMajor;  // Layout type for A matrix operand
+  constexpr int AlignmentA =
+      128 / cutlass::sizeof_bits<ElementA>::value;  // Memory access granularity/alignment of A matrix in units of
+                                                    // elements (up to 16 bytes)
+
+  // B matrix configuration
+  using ElementB = cutlass::float_e4m3_t;           // Element type for B matrix operand
+  using LayoutBTag = cutlass::layout::ColumnMajor;  // Layout type for B matrix operand
+  constexpr int AlignmentB =
+      128 / cutlass::sizeof_bits<ElementB>::value;  // Memory access granularity/alignment of B matrix in units of
+                                                    // elements (up to 16 bytes)
+
+  // C/D matrix configuration
+  using ElementD = OutType;                      // Element type for D matrix operand
+  using ElementC = void;                         // Element type for C matrix operand
+  using LayoutCTag = cutlass::layout::RowMajor;  // Layout type for C matrix operand
+  using LayoutDTag = cutlass::layout::RowMajor;  // Layout type for D matrix operand
+  constexpr int AlignmentD =
+      128 / cutlass::sizeof_bits<ElementD>::value;  // Memory access granularity/alignment of D matrix in units of
+                                                    // elements (up to 16 bytes)
+  constexpr int AlignmentC =
+      AlignmentD;  // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
+
+  // Kernel functional config
+  using ElementAccumulator = float;      // Element type for internal accumulation
+  using ArchTag = cutlass::arch::Sm120;  // Tag indicating the minimum SM that supports the intended feature
+  using OperatorClass = cutlass::arch::OpClassTensorOp;  // Operator class tag - changed from OpClassBlockScaledTensorOp
+
+  static constexpr int ScaleMsPerTile = size<0>(ScalesPerTile{});
+  static constexpr int ScaleGranularityM = size<0>(MmaTileShape{}) / ScaleMsPerTile;
+  static constexpr int ScaleGranularityN = size<1>(MmaTileShape{}) / size<1>(ScalesPerTile{});
+  static constexpr int ScaleGranularityK = size<2>(MmaTileShape{}) / size<2>(ScalesPerTile{});
+
+  using ScaleConfig = cutlass::detail::Sm120BlockwiseScaleConfig<
+      ScaleGranularityM,
+      ScaleGranularityN,
+      ScaleGranularityK,
+      cute::UMMA::Major::MN,
+      cute::UMMA::Major::K>;
+  // FP8 Block-wise scaling configuration
+  using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());  // Layout type for SFA matrix operand
+  using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());  // Layout type for SFB matrix operand
+
+  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
+      ArchTag,
+      OperatorClass,
+      PerSmTileShape,
+      ClusterShape,
+      cutlass::epilogue::collective::EpilogueTileAuto,
+      ElementAccumulator,
+      ElementAccumulator,
+      ElementC,
+      LayoutCTag,
+      AlignmentC,
+      ElementD,
+      LayoutDTag,
+      AlignmentD,
+      cutlass::epilogue::collective::EpilogueScheduleAuto  // Epilogue schedule policy
+      >::CollectiveOp;
+
+  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
+      ArchTag,
+      OperatorClass,
+      ElementA,
+      cute::tuple<LayoutATag, LayoutSFA>,
+      AlignmentA,
+      ElementB,
+      cute::tuple<LayoutBTag, LayoutSFB>,
+      AlignmentB,
+      ElementAccumulator,
+      MmaTileShape,
+      ClusterShape,
+      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
+          sizeof(typename CollectiveEpilogue::SharedStorage))>,
+      cutlass::gemm::collective::KernelScheduleAuto  // Kernel schedule policy. Auto defaults to cooperative kernel
+                                                     // schedule
+      >::CollectiveOp;
+
+  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
+      Shape<int, int, int, int>,  // Indicates ProblemShape
+      CollectiveMainloop,
+      CollectiveEpilogue,
+      void>;
+
+  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+
+  Gemm gemm_op;
+
+  int m = a.size(0);
+  int k = a.size(1);
+  int n = b.size(1);
+
+  auto a_ptr = static_cast<ElementA*>(a.data_ptr());
+  auto b_ptr = static_cast<ElementB*>(b.data_ptr());
+  auto c_ptr = static_cast<ElementD*>(out.data_ptr());
+
+  auto scales_a_ptr = static_cast<ElementBlockScale*>(scales_a.data_ptr());
+  auto scales_b_ptr = static_cast<ElementBlockScale*>(scales_b.data_ptr());
+
+  using StrideA = typename Gemm::GemmKernel::StrideA;
+  using StrideB = typename Gemm::GemmKernel::StrideB;
+  using StrideD = typename Gemm::GemmKernel::StrideD;
+  using StrideC = typename Gemm::GemmKernel::StrideD;
+
+  StrideA stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));
+  StrideB stride_b = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1));
+  StrideC stride_c = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, 1));
+  LayoutSFA layout_SFA = ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, 1));
+  LayoutSFB layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1));
+
+  typename GemmKernel::MainloopArguments mainloop_args{
+      a_ptr, stride_a, b_ptr, stride_b, scales_a_ptr, layout_SFA, scales_b_ptr, layout_SFB};
+
+  typename GemmKernel::EpilogueArguments epilogue_args{{}, c_ptr, stride_c, c_ptr, stride_c};
+  epilogue_args.thread.alpha = 1.0f;
+
+  typename Gemm::Arguments args = {
+      cutlass::gemm::GemmUniversalMode::kGemm,
+      {m, n, k, 1},
+      mainloop_args,
+      epilogue_args,
+  };
+
+  auto can_implement = gemm_op.can_implement(args);
+  TORCH_CHECK(can_implement == cutlass::Status::kSuccess, cutlassGetStatusString(can_implement));
+
+  size_t workspace_size = gemm_op.get_workspace_size(args);
+  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
+
+  auto init_status = gemm_op.initialize(args, workspace.get());
+  TORCH_CHECK(init_status == cutlass::Status::kSuccess, cutlassGetStatusString(init_status));
+
+  auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
+  auto status = gemm_op.run(stream);
+  TORCH_CHECK(status == cutlass::Status::kSuccess, cutlassGetStatusString(status));
+}
+
+template <typename OutType>
+void sm120_fp8_blockwise_dispatch_shape(
+    torch::Tensor& out,
+    const torch::Tensor& a,
+    const torch::Tensor& b,
+    const torch::Tensor& scales_a,
+    const torch::Tensor& scales_b) {
+  using MmaTileShape = Shape<_128, _128, _128>;
+  using PerSmTileShape = Shape<_128, _128, _128>;
+  using EpilogueTileShape = Shape<_128, _64>;
+  using ScalesPerTile = Shape<_128, _1, _1>;
+  launch_sm120_fp8_blockwise_scaled_mm<OutType, MmaTileShape, PerSmTileShape, EpilogueTileShape, ScalesPerTile>(
+      out, a, b, scales_a, scales_b);
+}
+
 torch::Tensor fp8_blockwise_scaled_mm(
     const torch::Tensor& mat_a,
     const torch::Tensor& mat_b,
@@ -275,6 +445,21 @@ torch::Tensor fp8_blockwise_scaled_mm(
   }
 #endif
 #endif
+
+#if defined(CUTLASS_ARCH_MMA_SM120A_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
+#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080
+  if (sm_version == 120) {
+    if (out_dtype == torch::kBFloat16) {
+      sm120_fp8_blockwise_dispatch_shape<cutlass::bfloat16_t>(
+          out_padded, mat_a_padded, mat_b, scales_a_padded, scales_b);
+    } else {
+      sm120_fp8_blockwise_dispatch_shape<cutlass::half_t>(out_padded, mat_a_padded, mat_b, scales_a_padded, scales_b);
+    }
+    return out_padded.slice(0, 0, original_rows);
+  }
+#endif
+#endif
+
   TORCH_CHECK_NOT_IMPLEMENTED(
       false, "No implemented fp8_blockwise_scaled_mm for current compute capability: ", sm_version);
 }
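For reference, a minimal smoke test for the new SM120 path might look like the sketch below. It is not part of the patch, and several details are assumptions: the wrapper's full signature (the hunk context above only shows mat_a and mat_b; scales_a, scales_b, and out_dtype are inferred from the function body), the helper name sm120_smoke_test, and the exact scale-tensor shapes and strides the wrapper accepts. The shapes follow the config picked by sm120_fp8_blockwise_dispatch_shape: ScalesPerTile = Shape<_128, _1, _1> against a 128x128x128 MMA tile resolves to ScaleGranularityM = 1, ScaleGranularityN = 128, ScaleGranularityK = 128, i.e. one float scale per 1x128 strip of A and one per 128x128 block of B.

// Hypothetical smoke test for the SM120 path; not part of this patch.
#include <torch/torch.h>

// Assumed full signature; the diff's hunk context shows only mat_a/mat_b.
torch::Tensor fp8_blockwise_scaled_mm(
    const torch::Tensor& mat_a,
    const torch::Tensor& mat_b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const torch::Dtype& out_dtype);

torch::Tensor sm120_smoke_test() {
  const int64_t m = 256, n = 512, k = 1024;  // multiples of the 128-wide tiles
  auto cuda = torch::TensorOptions().device(torch::kCUDA);

  // A: row-major [m, k] FP8. B: column-major [k, n] FP8, built here as a
  // transposed view of a row-major [n, k] tensor.
  auto a = torch::randn({m, k}, cuda).to(torch::kFloat8_e4m3fn);
  auto b = torch::randn({n, k}, cuda).to(torch::kFloat8_e4m3fn).t();

  // One float32 scale per 1x128 strip of A and per 128x128 block of B
  // (assumed layout; the wrapper may expect different shapes or strides).
  auto scales_a = torch::ones({m, k / 128}, cuda.dtype(torch::kFloat32));
  auto scales_b = torch::ones({k / 128, n / 128}, cuda.dtype(torch::kFloat32));

  return fp8_blockwise_scaled_mm(a, b, scales_a, scales_b, torch::kBFloat16);
}

Note that M need not be tile-aligned in practice: the second hunk dispatches on the already-padded out_padded/mat_a_padded/scales_a_padded tensors and slices the result back to original_rows before returning.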