diff --git a/sgl-kernel/csrc/cutlass_extensions/common.hpp b/sgl-kernel/csrc/cutlass_extensions/common.hpp
new file mode 100644
index 000000000..912ef8b16
--- /dev/null
+++ b/sgl-kernel/csrc/cutlass_extensions/common.hpp
@@ -0,0 +1,25 @@
+#pragma once
+
+#include <utility>
+
+#include "cuda_runtime.h"
+#include "cutlass/cutlass.h"
+
+/**
+ * A wrapper for a kernel that is used to guard against compilation on
+ * architectures that will never use the kernel. The purpose of this is to
+ * reduce the size of the compiled binary.
+ * __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
+ * into code that will be executed on the device where it is defined.
+ *
+ * When __CUDA_ARCH__ is below 900 the call operator has an empty body.
+ */
+template <typename Kernel>
+struct enable_sm90_or_later : Kernel {
+  template <typename... Args>
+  CUTLASS_DEVICE void operator()(Args&&... args) {
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900
+    Kernel::operator()(std::forward<Args>(args)...);
+#endif
+  }
+};
diff --git a/sgl-kernel/csrc/cutlass_extensions/gemm/cutlass_gemm_caller.cuh b/sgl-kernel/csrc/cutlass_extensions/gemm/cutlass_gemm_caller.cuh
new file mode 100644
index 000000000..737916d89
--- /dev/null
+++ b/sgl-kernel/csrc/cutlass_extensions/gemm/cutlass_gemm_caller.cuh
@@ -0,0 +1,72 @@
+// Adapted from
+// https://github.com/vllm-project/vllm/blob/main/csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh
+
+#pragma once
+
+// clang-format will break include orders
+// clang-format off
+#include <torch/all.h>
+
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAStream.h>
+
+#include "cutlass/cutlass.h"
+
+#include "cute/tensor.hpp"
+#include "cute/atom/mma_atom.hpp"
+#include "cutlass/numeric_types.h"
+
+#include "cutlass/gemm/device/gemm_universal_adapter.h"
+#include "cutlass/gemm/kernel/gemm_universal.hpp"
+#include "cutlass/epilogue/collective/collective_builder.hpp"
+#include "cutlass/gemm/collective/collective_builder.hpp"
+#include "cutlass/util/packed_stride.hpp"
+
+// clang-format on
+
+/**
+ * Helper macro for checking CUTLASS errors: fails through TORCH_CHECK, with
+ * the CUTLASS status string as the message, unless `status` is kSuccess.
+ */
+#define CUTLASS_CHECK(status)                                                       \
+  {                                                                                 \
+    cutlass::Status error = status;                                                 \
+    TORCH_CHECK(error == cutlass::Status::kSuccess, cutlassGetStatusString(error)); \
+  }
+
+/**
+ * Assembles the arguments for `GemmKernel`, checks that the kernel can
+ * implement them, allocates a byte workspace on `device`, and runs the kernel
+ * on the current CUDA stream of `device`.
+ *
+ * Any non-success CUTLASS status is reported through CUTLASS_CHECK.
+ */
+template <typename GemmKernel>
+void cutlass_gemm_caller(
+    torch::Device device,
+    cute::Shape<int, int, int, int> prob_shape,
+    typename GemmKernel::MainloopArguments mainloop_args,
+    typename GemmKernel::EpilogueArguments epilogue_args,
+    typename GemmKernel::TileSchedulerArguments scheduler = {}) {
+  // Describe `device` itself rather than whichever device happens to be
+  // current, so the hardware info agrees with the workspace and stream below.
+  cutlass::KernelHardwareInfo hw_info;
+  hw_info.device_id = device.index();
+  hw_info.sm_count = at::cuda::getDeviceProperties(device.index())->multiProcessorCount;
+  typename GemmKernel::Arguments args{
+      cutlass::gemm::GemmUniversalMode::kGemm, prob_shape, mainloop_args, epilogue_args, hw_info, scheduler};
+
+  // Launch the CUTLASS GEMM kernel.
+  using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+  GemmOp gemm_op;
+  CUTLASS_CHECK(gemm_op.can_implement(args));
+
+  size_t workspace_size = gemm_op.get_workspace_size(args);
+  auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(device);
+  auto workspace = torch::empty(workspace_size, workspace_options);
+
+  auto stream = at::cuda::getCurrentCUDAStream(device.index());
+
+  cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);
+  CUTLASS_CHECK(status);
+}
diff --git a/sgl-kernel/csrc/cutlass_extensions/gemm/dispatch_policy.hpp b/sgl-kernel/csrc/cutlass_extensions/gemm/dispatch_policy.hpp
new file mode 100644
index 000000000..6019e4a53
--- /dev/null
+++ b/sgl-kernel/csrc/cutlass_extensions/gemm/dispatch_policy.hpp
@@ -0,0 +1,38 @@
+// Adapted from https://github.com/vllm-project/vllm/blob/main/csrc/cutlass_extensions/gemm/dispatch_policy.hpp
+
+#pragma once
+
+#include "cutlass/gemm/dispatch_policy.hpp"
+
+namespace cutlass::gemm {
+
+//////////////////////////////////////////////////////////////////////////////
+
+// FP8 related policies (including Blocked Scaled Accumulation)
+// `ScaleGranularityM` specifies scaling granularity along M, while zero-value
+// `ScaleGranularityM` indicates that scaling granularity is
+// `size<0>(TileShape_MNK{})` along M.
+template <int ScaleGranularityM = 0>
+struct KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum : KernelTmaWarpSpecializedCooperative {};
+
+// n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, Warp
+// specialized dynamic schedule For FP8 kernels with Block Scaling
+template <
+    int Stages_,
+    class ClusterShape_ = Shape<_1, _1, _1>,
+    class KernelSchedule = KernelTmaWarpSpecialized,
+    int ScaleGranularityM = 0  // `ScaleGranularityM` specifies scaling granularity along M,
+                               // while zero-value `ScaleGranularityM` indicates that scaling
+                               // granularity is `size<0>(TileShape_MNK{})` along M.
+    >
+struct MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8
+    : MainloopSm90TmaGmmaWarpSpecialized<Stages_, ClusterShape_, KernelSchedule> {
+  static_assert(
+      cute::
+          is_same_v<KernelSchedule, KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>>,
+      "KernelSchedule must be one of the warp specialized policies");
+};
+
+//////////////////////////////////////////////////////////////////////////////
+
+}  // namespace cutlass::gemm
diff --git a/sgl-kernel/csrc/cutlass_extensions/gemm/fp8_blockwise_gemm_sm90_dispatch.cuh b/sgl-kernel/csrc/cutlass_extensions/gemm/fp8_blockwise_gemm_sm90_dispatch.cuh
new file mode 100644
index 000000000..8deda43e5
--- /dev/null
+++ b/sgl-kernel/csrc/cutlass_extensions/gemm/fp8_blockwise_gemm_sm90_dispatch.cuh
@@ -0,0 +1,197 @@
+// Adapted from
+// https://github.com/vllm-project/vllm/blob/main/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh
+#pragma once
+
+#include "cute/tensor.hpp"
+#include "cutlass/cutlass.h"
+#include "cutlass/epilogue/collective/collective_builder.hpp"
+#include "cutlass/epilogue/dispatch_policy.hpp"
+#include "cutlass/gemm/collective/collective_builder.hpp"
+#include "cutlass/gemm/device/gemm_universal_adapter.h"
+#include "cutlass/gemm/dispatch_policy.hpp"
+#include "cutlass/gemm/kernel/gemm_universal.hpp"
+#include "cutlass/gemm/kernel/tile_scheduler_params.h"
+#include "cutlass/numeric_types.h"
+#include "cutlass/tensor_ref.h"
+#include "cutlass_extensions/common.hpp"
+#include "cutlass_extensions/gemm/cutlass_gemm_caller.cuh"
+#include "cutlass_extensions/gemm/dispatch_policy.hpp"
+
+using namespace cute;
+
+template <
+    typename SchedulerType,
+    typename OutType,
+    int GroupSizeM_,
+    int GroupSizeN_,
+    int GroupSizeK_,
+    int TileSizeM_ = 128,
+    class ClusterShape = Shape<_1, _2, _1>>
+struct cutlass_3x_gemm_fp8_blockwise {
+  using GroupSizeM = Int<GroupSizeM_>;
+  using GroupSizeN = Int<GroupSizeN_>;
+  using GroupSizeK = Int<GroupSizeK_>;
+  using TileSizeM = Int<TileSizeM_>;
+
+  static_assert(TileSizeM_ % GroupSizeM_ == 0, "TileSizeM must be a multiple of GroupSizeM");
+
+  using ElementAB = cutlass::float_e4m3_t;
+
+  // A matrix configuration
+  using ElementA = ElementAB;
+  using LayoutA = cutlass::layout::RowMajor;
+  static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
+
+  // B matrix configuration
+  using ElementB = ElementAB;
+  using LayoutB = cutlass::layout::ColumnMajor;
+  static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
+
+  // C/D matrix configuration
+  using ElementC = void;
+  using LayoutC = cutlass::layout::RowMajor;
+  static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<OutType>::value;
+
+  using ElementD = OutType;
+  using LayoutD = cutlass::layout::RowMajor;
+  static constexpr int AlignmentD = AlignmentC;
+
+  using ScaleTileShape = Shape<_1, _128, _128>;
+  using ScaleConfig = decltype(cutlass::detail::sm90_trivial_blockwise_scale_config(ScaleTileShape{}));
+  using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
+  using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
+
+  // Multiply-accumulate blocking/pipelining details
+  using ElementAccumulator = float;                            // Element type for internal accumulation
+  using ElementCompute = float;                                // Element type for compute
+  using TileShape = Shape<TileSizeM, GroupSizeN, GroupSizeK>;  // Threadblock-level tile size
+
+  using ArchTag = cutlass::arch::Sm90;
+  using OperatorClass = cutlass::arch::OpClassTensorOp;
+  using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
+  using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
+  using StoreEpilogueCompute = typename cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90AccFetch>;
+
+  using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum;
+  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
+      ArchTag,
+      OperatorClass,
+      TileShape,
+      ClusterShape,
+      EpilogueTileType,
+      ElementAccumulator,
+      ElementCompute,
+      ElementC,
+      LayoutC,
+      AlignmentC,
+      ElementD,
+      LayoutD,
+      AlignmentD,
+      EpilogueSchedule,
+      StoreEpilogueCompute>::CollectiveOp;
+
+  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
+      ArchTag,
+      OperatorClass,
+      ElementA,
+      cute::tuple<LayoutA, LayoutSFA>,
+      AlignmentA,
+      ElementB,
+      cute::tuple<LayoutB, LayoutSFB>,
+      AlignmentB,
+      ElementAccumulator,
+      TileShape,
+      ClusterShape,
+      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
+          sizeof(typename CollectiveEpilogue::SharedStorage))>,
+      KernelSchedule>::CollectiveOp;
+
+  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
+      Shape<int, int, int, int>,  // Indicates ProblemShape
+      CollectiveMainloop,
+      CollectiveEpilogue,
+      SchedulerType>;
+};
+
+template <typename Gemm>
+void cutlass_gemm_caller_blockwise(
+    torch::Tensor& out,
+    torch::Tensor const& a,
+    torch::Tensor const& b,
+    torch::Tensor const& a_scales,
+    torch::Tensor const& b_scales) {
+  using GemmKernel = typename Gemm::GemmKernel;
+  using ElementAB = typename Gemm::ElementAB;
+  using ElementA = ElementAB;
+  using ElementB = ElementAB;
+  using ElementD = typename Gemm::ElementD;
+  using ElementBlockScale = float;
+
+  // Reuse the scale configuration the kernel type was built with, so the two cannot drift apart.
+  using ScaleConfig = typename Gemm::ScaleConfig;
+  using LayoutSFA = typename Gemm::LayoutSFA;
+  using LayoutSFB = typename Gemm::LayoutSFB;
+
+  int m = a.size(0);
+  int k = a.size(1);
+  int n = b.size(1);
+
+  auto a_ptr = static_cast<ElementA*>(a.data_ptr());
+  auto b_ptr = static_cast<ElementB*>(b.data_ptr());
+
+  auto a_s_ptr = static_cast<ElementBlockScale*>(a_scales.data_ptr());
+  auto b_s_ptr = static_cast<ElementBlockScale*>(b_scales.data_ptr());
+
+  using StrideA = typename GemmKernel::StrideA;
+  using StrideB = typename GemmKernel::StrideB;
+  using StrideD = typename GemmKernel::StrideD;
+  using StrideC = typename GemmKernel::StrideC;
+
+  StrideA a_stride = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));
+  StrideB b_stride = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1));
+  StrideC c_stride = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, 1));
+  LayoutSFA layout_sfa = ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, 1));
+  LayoutSFB layout_sfb = ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1));
+
+  typename GemmKernel::MainloopArguments mainloop_args{
+      a_ptr, a_stride, b_ptr, b_stride, a_s_ptr, layout_sfa, b_s_ptr, layout_sfb};
+  auto c_ptr = static_cast<ElementD*>(out.data_ptr());
+  typename GemmKernel::EpilogueArguments epilogue_args{{}, c_ptr, c_stride, c_ptr, c_stride};
+
+  typename GemmKernel::TileSchedulerArguments scheduler;
+
+  static constexpr bool UsesStreamKScheduler =
+      cute::is_same_v<typename GemmKernel::TileSchedulerTag, cutlass::gemm::StreamKScheduler>;
+
+  if constexpr (UsesStreamKScheduler) {
+    using DecompositionMode =
+        typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode;
+    using ReductionMode =
+        typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::ReductionMode;
+
+    scheduler.decomposition_mode = DecompositionMode::StreamK;
+    scheduler.reduction_mode = ReductionMode::Nondeterministic;
+  }
+
+  cutlass_gemm_caller<GemmKernel>(a.device(), {m, n, k, 1}, mainloop_args, epilogue_args, scheduler);
+}
+
+template <typename OutType>
+void cutlass_gemm_blockwise_sm90_fp8_dispatch(
+    torch::Tensor& out,
+    torch::Tensor const& a,
+    torch::Tensor const& b,
+    torch::Tensor const& a_scales,
+    torch::Tensor const& b_scales) {
+  auto k = a.size(1);
+  auto n = b.size(1);
+
+  if (k > 3 * n) {
+    cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<cutlass::gemm::StreamKScheduler, OutType, 1, 128, 128>>(
+        out, a, b, a_scales, b_scales);
+  } else {
+    cutlass_gemm_caller_blockwise<
+        cutlass_3x_gemm_fp8_blockwise<cutlass::gemm::PersistentScheduler, OutType, 1, 128, 128>>(
+        out, a, b, a_scales, b_scales);
+  }
+}
diff --git a/sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu b/sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu old mode 100755 new mode 100644 index 1c082da4e..e69167a4d --- a/sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu +++ b/sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu @@ -30,138 +30,12 @@ #include #include +#include "cutlass_extensions/gemm/cutlass_gemm_caller.cuh" +#include "cutlass_extensions/gemm/fp8_blockwise_gemm_sm90_dispatch.cuh" #include "utils.h" using namespace cute; -template -void launch_sm90_fp8_blockwise_scaled_mm( - torch::Tensor& out, - const torch::Tensor& a, - const torch::Tensor& b, - const torch::Tensor& scales_a, - const torch::Tensor& scales_b) { - using ElementAccumulator = float; - using ElementCompute = float; - using ElementBlockScale = float; - - using ElementA = cutlass::float_e4m3_t; - using LayoutA = cutlass::layout::RowMajor; - constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; - - using ElementB = cutlass::float_e4m3_t; - using LayoutB = cutlass::layout::ColumnMajor; - constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; - - using ElementC = void; - using LayoutC = cutlass::layout::RowMajor; - constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; - - using ElementD = OutType; - using LayoutD = cutlass::layout::RowMajor; - constexpr int AlignmentD = AlignmentC; - - using ScaleTileShape = Shape<_1, _128, _128>; - using ScaleConfig = decltype(cutlass::detail::sm90_trivial_blockwise_scale_config(ScaleTileShape{})); - using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); - using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); - - using ArchTag = cutlass::arch::Sm90; - using OperatorClass = cutlass::arch::OpClassTensorOp; - using EpilogueSchedule = 
cutlass::epilogue::TmaWarpSpecializedCooperative; - using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; - using StoreEpilogueCompute = typename cutlass::epilogue::fusion::Sm90EVT; - - using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum; - using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< - ArchTag, - OperatorClass, - TileShape, - ClusterShape, - EpilogueTileType, - ElementAccumulator, - ElementCompute, - ElementC, - LayoutC, - AlignmentC, - ElementD, - LayoutD, - AlignmentD, - EpilogueSchedule, - StoreEpilogueCompute>::CollectiveOp; - - using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< - ArchTag, - OperatorClass, - ElementA, - cute::tuple, - AlignmentA, - ElementB, - cute::tuple, - AlignmentB, - ElementAccumulator, - TileShape, - ClusterShape, - cutlass::gemm::collective::StageCountAutoCarveout( - sizeof(typename CollectiveEpilogue::SharedStorage))>, - KernelSchedule>::CollectiveOp; - - using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - Shape, // Indicates ProblemShape - CollectiveMainloop, - CollectiveEpilogue, - SchedulerType>; - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - - Gemm gemm_op; - - int m = a.size(0); - int k = a.size(1); - int n = b.size(1); - - auto a_ptr = static_cast(a.data_ptr()); - auto b_ptr = static_cast(b.data_ptr()); - auto o_ptr = static_cast(out.data_ptr()); - - auto a_s_ptr = static_cast(scales_a.data_ptr()); - auto b_s_ptr = static_cast(scales_b.data_ptr()); - - using StrideA = typename Gemm::GemmKernel::StrideA; - using StrideB = typename Gemm::GemmKernel::StrideB; - using StrideC = typename Gemm::GemmKernel::StrideC; - using StrideD = typename Gemm::GemmKernel::StrideD; - - StrideA stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1)); - StrideB stride_b = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1)); - StrideC stride_c; - StrideD 
stride_d = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(m, n, 1)); - - LayoutSFA layout_sfa = ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, 1)); - LayoutSFB layout_sfb = ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1)); - - typename GemmKernel::MainloopArguments mainloop_args{ - a_ptr, stride_a, b_ptr, stride_b, a_s_ptr, layout_sfa, b_s_ptr, layout_sfb}; - typename GemmKernel::EpilogueArguments epilogue_args{{}, nullptr, stride_d, o_ptr, stride_d}; - - typename Gemm::Arguments args = { - cutlass::gemm::GemmUniversalMode::kGemm, - {m, n, k, 1}, - mainloop_args, - epilogue_args, - }; - - size_t workspace_size = gemm_op.get_workspace_size(args); - auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a.device()); - auto workspace = torch::empty(workspace_size, workspace_options); - auto stream = at::cuda::getCurrentCUDAStream(a.get_device()); - - auto can_implement = gemm_op.can_implement(args); - TORCH_CHECK(can_implement == cutlass::Status::kSuccess, cutlassGetStatusString(can_implement)) - - auto status = gemm_op.run(args, workspace.data_ptr(), stream); - TORCH_CHECK(status == cutlass::Status::kSuccess, cutlassGetStatusString(status)) -} - template < typename OutType, typename MmaTileShape, @@ -297,27 +171,6 @@ void launch_sm100_fp8_blockwise_scaled_mm( TORCH_CHECK(status == cutlass::Status::kSuccess, cutlassGetStatusString(status)) } -template -void sm90_fp8_blockwise_dispatch_shape( - torch::Tensor& out, - const torch::Tensor& a, - const torch::Tensor& b, - const torch::Tensor& scales_a, - const torch::Tensor& scales_b) { - using TileShape = Shape<_128, _128, _128>; - using ClusterShape = Shape<_1, _2, _1>; - - auto k = a.size(1); - auto n = b.size(1); - if (k > 3 * n) { - launch_sm90_fp8_blockwise_scaled_mm( - out, a, b, scales_a, scales_b); - } else { - launch_sm90_fp8_blockwise_scaled_mm( - out, a, b, scales_a, scales_b); - } -} - template void sm100_fp8_blockwise_dispatch_shape( torch::Tensor& 
out, @@ -394,10 +247,10 @@ torch::Tensor fp8_blockwise_scaled_mm( if (sm_version == 90) { torch::Tensor scales_b_contiguous = scales_b.contiguous(); if (out_dtype == torch::kBFloat16) { - sm90_fp8_blockwise_dispatch_shape( + cutlass_gemm_blockwise_sm90_fp8_dispatch( out_padded, mat_a_padded, mat_b, scales_a_padded, scales_b_contiguous); } else { - sm90_fp8_blockwise_dispatch_shape( + cutlass_gemm_blockwise_sm90_fp8_dispatch( out_padded, mat_a_padded, mat_b, scales_a_padded, scales_b_contiguous); } return out_padded.slice(0, 0, original_rows);