// NOTE(review): This region arrived corrupted -- the git patch framing (email
// headers, diff hunks, '+' prefixes) was collapsed, and every template-argument
// list beginning with a letter (e.g. "<ElementType, ...>") was stripped as if
// it were an HTML tag; only underscore/digit-initial lists such as
// Shape<_256, _128, _64> and cute::Int<0> survived. The code below restores the
// SM100 (Blackwell) FP8 rowwise-scaled GEMM path added by patch ad4e58b
// ("Support fp8 gemm for blackwell", #4558). Restored argument lists mirror the
// SM90 rowwise path this file already contains -- TODO(review): diff against
// the upstream patch to confirm each restored list.

#if defined CUDA_VERSION && CUDA_VERSION >= 12080
// Rowwise-scaled FP8 GEMM for SM100: D = scale_a (col) * scale_b (row) * acc [+ bias].
// The epilogue is expressed as a CUTLASS Epilogue Visitor Tree (EVT).
template <
    typename ElementType,
    typename OutElementType,
    typename AccumElementType,
    typename CTAShape,
    typename ClusterShape,
    typename MainloopScheduleType,
    typename EpilogueScheduleType,
    typename TileSchedulerType = void,
    bool WithBias = false>
struct DeviceGemmFp8RowwiseSm100 {
  static_assert(std::is_same_v<ElementType, cutlass::float_e4m3_t>, "ElementType must be FP8(e4m3)");
  using TileShape = CTAShape;
  // Fetches the accumulator produced by the mainloop.
  using Accum = cutlass::epilogue::fusion::Sm90AccFetch;

  using ElementComputeEpilogue = float;
  // Per-row scale of A, broadcast along columns (stride <1,0,0>).
  using ScaleA = cutlass::epilogue::fusion::Sm90ColBroadcast<
      0,
      TileShape,
      ElementComputeEpilogue,
      ElementComputeEpilogue,
      cute::Stride<cute::Int<1>, cute::Int<0>, cute::Int<0>>>;

  // Per-column scale of B, broadcast along rows (stride <0,1,0>).
  using ScaleB = cutlass::epilogue::fusion::Sm90RowBroadcast<
      0,
      TileShape,
      ElementComputeEpilogue,
      ElementComputeEpilogue,
      cute::Stride<cute::Int<0>, cute::Int<1>, cute::Int<0>>>;

  // Optional bias, one value per output column, broadcast along rows.
  using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast<
      0,
      TileShape,
      OutElementType,
      OutElementType,
      cute::Stride<cute::Int<0>, cute::Int<1>, cute::Int<0>>>;

  // Stage 0: acc * scale_b.
  using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies,
      ElementComputeEpilogue,
      ElementComputeEpilogue,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTCompute0 = cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;

  using LayoutA = cutlass::layout::RowMajor;
  static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementType>::value;

  using LayoutB = cutlass::layout::ColumnMajor;
  static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementType>::value;

  // No source tensor C is read; the epilogue writes D only.
  using ElementC = void;
  using LayoutC = cutlass::layout::RowMajor;
  static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<OutElementType>::value;

  using LayoutD = cutlass::layout::RowMajor;
  static constexpr int AlignmentD = AlignmentC;

  // Stage 1: scale_a * (stage 0) [+ bias], fused as multiply_add when biased.
  using Compute1MulAdd = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiply_add,
      OutElementType,
      ElementComputeEpilogue,
      cutlass::FloatRoundStyle::round_to_nearest>;
  using Compute1Mul = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies,
      OutElementType,
      ElementComputeEpilogue,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTCompute = typename std::conditional_t<
      WithBias,
      cutlass::epilogue::fusion::Sm90EVT<Compute1MulAdd, ScaleA, EVTCompute0, Bias>,
      cutlass::epilogue::fusion::Sm90EVT<Compute1Mul, ScaleA, EVTCompute0>>;
  using ArgumentType = typename EVTCompute::Arguments;
  // MMA type
  using ElementAccumulator = AccumElementType;

  // Epilogue types
  using ElementCompute = float;

  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      cutlass::arch::Sm100,
      cutlass::arch::OpClassTensorOp,
      TileShape,
      ClusterShape,
      cutlass::epilogue::collective::EpilogueTileAuto,
      ElementAccumulator,
      ElementCompute,
      ElementC,
      LayoutC,
      AlignmentC,
      OutElementType,
      LayoutD,
      AlignmentD,
      EpilogueScheduleType,
      EVTCompute>::CollectiveOp;

  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      cutlass::arch::Sm100,
      cutlass::arch::OpClassTensorOp,
      ElementType,
      LayoutA,
      AlignmentA,
      ElementType,
      LayoutB,
      AlignmentB,
      ElementAccumulator,
      TileShape,
      ClusterShape,
      // Carve the epilogue's shared-memory footprint out of the stage budget.
      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
          sizeof(typename CollectiveEpilogue::SharedStorage))>,
      MainloopScheduleType>::CollectiveOp;
  using GemmKernel = cutlass::gemm::kernel::
      GemmUniversal<Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>;
  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

 private:
  // Wraps a tensor's data pointer as the Arguments of one broadcast node.
  template <typename Descriptor, typename T>
  static auto args_from_tensor(torch::Tensor const& tensor) {
    using Arguments = typename Descriptor::Arguments;
    auto* data_ptr = static_cast<T*>(tensor.data_ptr());
    static_assert(
        std::is_same_v<Descriptor, ScaleA> || std::is_same_v<Descriptor, ScaleB> ||
        std::is_same_v<Descriptor, Bias>);
    return Arguments{data_ptr};
  }

 public:
  // Builds the EVT argument tree from the scale (and optional bias) tensors.
  static ArgumentType prepare_args(
      torch::Tensor const& a_scales,
      torch::Tensor const& b_scales,
      std::optional<torch::Tensor> const& bias = std::nullopt) {
    auto a_args = args_from_tensor<ScaleA, ElementComputeEpilogue>(a_scales);
    auto b_args = args_from_tensor<ScaleB, ElementComputeEpilogue>(b_scales);

    typename EVTCompute0::Arguments evt0_args{b_args, {}, {}};

    if constexpr (WithBias) {
      auto bias_args = args_from_tensor<Bias, OutElementType>(bias.value());
      return ArgumentType{a_args, evt0_args, bias_args, {}};
    } else {
      return ArgumentType{a_args, evt0_args, {}};
    }
  }
};

// Assembles the full GemmUniversal argument struct for one launch.
// Preconditions (enforced by the public entry point): a is row-major (m,k),
// b is column-major (k,n), out is (m,n).
template <typename GemmType, bool WithBias>
typename GemmType::Gemm::Arguments prepare_sm100_fp8_args(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  using Gemm = typename GemmType::Gemm;
  using ElementT = typename Gemm::ElementA;
  using ElementC = typename Gemm::ElementC;
  using ElementOutput = typename Gemm::ElementD;
  using ElementComputeEpilogue = float;
  using GemmKernel = typename Gemm::GemmKernel;

  using StrideA = typename Gemm::GemmKernel::StrideA;
  using StrideB = typename Gemm::GemmKernel::StrideB;
  using StrideC = typename Gemm::GemmKernel::StrideC;
  using StrideD = StrideC;
  using StrideAux = StrideC;

  int32_t m = a.size(0);
  int32_t n = b.size(1);
  int32_t k = a.size(1);

  ElementT const* ptr_a = reinterpret_cast<ElementT const*>(a.data_ptr());
  ElementT const* ptr_b = reinterpret_cast<ElementT const*>(b.data_ptr());

  StrideA stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));
  StrideB stride_b = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1));
  StrideC stride_c = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, 1));
  StrideD stride_d = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(m, n, 1));
  StrideAux aux_stride = stride_d;

  typename GemmKernel::MainloopArguments mainloop_args{ptr_a, stride_a, ptr_b, stride_b};

  typename GemmKernel::ProblemShape prob_shape = {m, n, k, 1};

  cutlass::KernelHardwareInfo hw_info;
  typename GemmKernel::TileSchedulerArguments scheduler = {};

  auto ptr_c = static_cast<ElementOutput*>(out.data_ptr());

  // C is void, so ptr_c/stride_c double as the (unread) source slot; D aliases out.
  auto prepare_epilogue_args = [&](const c10::optional<torch::Tensor>& bias = c10::nullopt) {
    if constexpr (WithBias) {
      TORCH_CHECK(bias.has_value(), "Bias tensor is required but not provided.");
      return typename GemmKernel::EpilogueArguments{
          GemmType::prepare_args(scales_a, scales_b, bias.value()), ptr_c, stride_c, ptr_c, stride_d};
    } else {
      return typename GemmKernel::EpilogueArguments{
          GemmType::prepare_args(scales_a, scales_b), ptr_c, stride_c, ptr_c, stride_d};
    }
  };

  typename GemmKernel::Arguments args{
      cutlass::gemm::GemmUniversalMode::kGemm,
      prob_shape,
      mainloop_args,
      prepare_epilogue_args(bias),
      hw_info,
      scheduler};
  return args;
}

// Allocates workspace, validates the problem, and runs the kernel on the
// current CUDA stream of `a`'s device.
template <typename Gemm, bool WithBias>
void launch_sm100_fp8_scaled_mm(
    torch::Tensor& out,
    torch::Tensor const& a,
    torch::Tensor const& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  auto args = prepare_sm100_fp8_args<Gemm, WithBias>(out, a, b, scales_a, scales_b, bias);

  typename Gemm::Gemm gemm_op;
  size_t workspace_size = gemm_op.get_workspace_size(args);
  auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
  auto workspace = torch::empty(workspace_size, workspace_options);
  auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
  auto can_implement = gemm_op.can_implement(args);
  // NOTE(review): the corrupted text lacked the ';' after both TORCH_CHECKs;
  // restored here (the macro expansion requires a terminating statement).
  TORCH_CHECK(can_implement == cutlass::Status::kSuccess);
  auto status = gemm_op.run(args, workspace.data_ptr(), stream);
  TORCH_CHECK(status == cutlass::Status::kSuccess);
}

// Instantiates the biased or bias-free Gemm type and launches it.
// Tile/cluster shapes below survived the corruption verbatim.
template <typename OutType>
void sm100_fp8_dispatch_bias(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  using CTAShape = Shape<_256, _128, _64>;
  using ClusterShape = Shape<_2, _2, _1>;
  using MainloopScheduleType = cutlass::gemm::collective::KernelScheduleAuto;
  using EpilogueScheduleType = cutlass::epilogue::collective::EpilogueScheduleAuto;
  using TileSchedulerType = void;

  using ElementInput = cutlass::float_e4m3_t;
  using ElementOutput = OutType;
  using AccumElementType = float;

  if (bias) {
    using Gemm = DeviceGemmFp8RowwiseSm100<
        ElementInput,
        ElementOutput,
        AccumElementType,
        CTAShape,
        ClusterShape,
        MainloopScheduleType,
        EpilogueScheduleType,
        TileSchedulerType,
        true>;
    return launch_sm100_fp8_scaled_mm<Gemm, true>(out, a, b, scales_a, scales_b, bias);
  } else {
    using Gemm = DeviceGemmFp8RowwiseSm100<
        ElementInput,
        ElementOutput,
        AccumElementType,
        CTAShape,
        ClusterShape,
        MainloopScheduleType,
        EpilogueScheduleType,
        TileSchedulerType,
        false>;
    return launch_sm100_fp8_scaled_mm<Gemm, false>(out, a, b, scales_a, scales_b, bias);
  }
}

// Shape-level dispatch entry point for SM100; currently a single configuration.
template <typename OutType>
void sm100_fp8_dispatch_shape(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  return sm100_fp8_dispatch_bias<OutType>(out, a, b, scales_a, scales_b, bias);
}
#endif

// NOTE(review): the patch's second hunk is truncated in the corrupted text.
// It inserts the following dispatch into fp8_scaled_mm(), immediately after
// `auto sm_version = getSMVersion();` and before the existing SM90 branch:
//
// #if defined CUDA_VERSION && CUDA_VERSION >= 12080
//   if (sm_version >= 100) {
//     if (out_dtype == torch::kBFloat16) {
//       sm100_fp8_dispatch_shape<cutlass::bfloat16_t>(out, mat_a, mat_b, scales_a, scales_b, bias);
//     } else {
//       sm100_fp8_dispatch_shape<cutlass::half_t>(out, mat_a, mat_b, scales_a, scales_b, bias);
//     }
//     return out;
//   }
// #endif