diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt
old mode 100644
new mode 100755
index 929a60d3b..7a8b8601d
--- a/sgl-kernel/CMakeLists.txt
+++ b/sgl-kernel/CMakeLists.txt
@@ -195,6 +195,7 @@ set(SOURCES
     "csrc/moe/moe_align_kernel.cu"
     "csrc/moe/moe_fused_gate.cu"
     "csrc/moe/moe_topk_softmax_kernels.cu"
+    "csrc/moe/fp8_blockwise_moe_kernel.cu"
     "csrc/speculative/eagle_utils.cu"
     "csrc/speculative/speculative_sampling.cu"
     "csrc/speculative/packbit.cu"
diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc
old mode 100644
new mode 100755
index c0ab4e0e2..7c0df156d
--- a/sgl-kernel/csrc/common_extension.cc
+++ b/sgl-kernel/csrc/common_extension.cc
@@ -150,6 +150,11 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
       "n_share_experts_fusion, float routed_scaling_factor) -> "
       "(Tensor[])");
   m.impl("moe_fused_gate", torch::kCUDA, &moe_fused_gate);
+  m.def(
+      "fp8_blockwise_scaled_grouped_mm(Tensor output, Tensor a, Tensor b, Tensor scales_a, Tensor scales_b, Tensor "
+      "stride_a, Tensor stride_b, Tensor stride_c, Tensor layout_sfa, Tensor layout_sfb, Tensor problem_sizes, Tensor "
+      "expert_offsets) -> ()");
+  m.impl("fp8_blockwise_scaled_grouped_mm", torch::kCUDA, &fp8_blockwise_scaled_grouped_mm);

   /*
    * From csrc/speculative
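Before the kernel sources, it helps to pin down the contract the new op implements: per expert group `i`, it is a plain scaled matmul once the blockwise scales are broadcast out. A minimal PyTorch sketch of the reference semantics (illustrative names only; the test file added below implements the same reference as `baseline_scaled_mm`):

```python
import torch

# Reference semantics for one expert group i, with scale_a_i/scale_b_i
# already broadcast to match A_i and B_i: dequantize, then multiply.
def reference_expert_mm(a_i, b_i, scale_a_i, scale_b_i, out_dtype=torch.bfloat16):
    a_f32 = scale_a_i * a_i.to(torch.float32)
    b_f32 = scale_b_i * b_i.to(torch.float32)
    return (a_f32 @ b_f32).to(out_dtype)
```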
diff --git a/sgl-kernel/csrc/moe/cutlass_moe_helper.cu b/sgl-kernel/csrc/moe/cutlass_moe_helper.cu
new file mode 100644
index 000000000..e8af2093e
--- /dev/null
+++ b/sgl-kernel/csrc/moe/cutlass_moe_helper.cu
@@ -0,0 +1,142 @@
+#pragma once
+
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <torch/all.h>
+
+#include "cutlass/bfloat16.h"
+#include "cutlass/float8.h"
+
+template <
+    typename ElementAB,
+    typename ElementC,
+    typename ElementAccumulator,
+    typename LayoutSFA,
+    typename LayoutSFB,
+    typename ScaleConfig>
+__global__ void get_group_gemm_starts(
+    int32_t* expert_offsets,
+    ElementAB** a_offsets,
+    ElementAB** b_offsets,
+    ElementC** out_offsets,
+    ElementAccumulator** a_scales_offsets,
+    ElementAccumulator** b_scales_offsets,
+    ElementAB* a_base_as_int,
+    ElementAB* b_base_as_int,
+    ElementC* out_base_as_int,
+    ElementAccumulator* a_scales_base_as_int,
+    ElementAccumulator* b_scales_base_as_int,
+    LayoutSFA* layout_sfa_base_as_int,
+    LayoutSFB* layout_sfb_base_as_int,
+    int* problem_sizes,
+    int* problem_sizes_transpose,
+    bool transpose = false) {
+  int expert_id = threadIdx.x;
+
+  if (expert_id >= gridDim.x * blockDim.x) {
+    return;
+  }
+
+  int m = problem_sizes[expert_id * 3];
+  int n = problem_sizes[expert_id * 3 + 1];
+  int k = problem_sizes[expert_id * 3 + 2];
+  if (transpose) {
+    problem_sizes_transpose[expert_id * 3] = n;
+    problem_sizes_transpose[expert_id * 3 + 1] = m;
+    problem_sizes_transpose[expert_id * 3 + 2] = k;
+  }
+
+  int32_t expert_offset = expert_offsets[expert_id];
+  int a_stride = 0;
+  int b_stride = 0;
+  int a_scale_stride = 0;
+  int b_scale_stride = 0;
+  if (!transpose) {
+    a_stride = expert_offset * k;
+    b_stride = expert_id * k * n;
+    a_scale_stride = expert_offset * k / 128;
+    b_scale_stride = expert_id * k * n / 128 / 128;
+  } else {
+    a_stride = expert_id * k * n;
+    b_stride = expert_offset * k;
+    a_scale_stride = expert_id * k * n / 128 / 128;
+    b_scale_stride = expert_offset * k / 128;
+  }
+  a_offsets[expert_id] = a_base_as_int + a_stride;
+  b_offsets[expert_id] = b_base_as_int + b_stride;
+  out_offsets[expert_id] = out_base_as_int + expert_offset * n;
+  a_scales_offsets[expert_id] = a_scales_base_as_int + a_scale_stride;
+  b_scales_offsets[expert_id] = b_scales_base_as_int + b_scale_stride;
+
+  LayoutSFA* layout_sfa_ptr = layout_sfa_base_as_int + expert_id;
+  LayoutSFB* layout_sfb_ptr = layout_sfb_base_as_int + expert_id;
+
+  if (!transpose) {
+    *layout_sfa_ptr = ScaleConfig::tile_atom_to_shape_SFA(cute::make_shape(m, n, k, 1));
+    *layout_sfb_ptr = ScaleConfig::tile_atom_to_shape_SFB(cute::make_shape(m, n, k, 1));
+  } else {
+    *layout_sfa_ptr = ScaleConfig::tile_atom_to_shape_SFA(cute::make_shape(n, m, k, 1));
+    *layout_sfb_ptr = ScaleConfig::tile_atom_to_shape_SFB(cute::make_shape(n, m, k, 1));
+  }
+}
+
+#define __CALL_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE, LayoutSFA, LayoutSFB, ScaleConfig)          \
+  else if (out_tensors.dtype() == TENSOR_C_TYPE) {                                                  \
+    get_group_gemm_starts<cutlass::float_e4m3_t, C_TYPE, float, LayoutSFA, LayoutSFB, ScaleConfig>  \
+        <<<1, num_experts, 0, stream>>>(                                                            \
+            static_cast<int32_t*>(expert_offsets.data_ptr()),                                       \
+            static_cast<cutlass::float_e4m3_t**>(a_ptrs.data_ptr()),                                \
+            static_cast<cutlass::float_e4m3_t**>(b_ptrs.data_ptr()),                                \
+            static_cast<C_TYPE**>(out_ptrs.data_ptr()),                                             \
+            static_cast<float**>(a_scales_ptrs.data_ptr()),                                         \
+            static_cast<float**>(b_scales_ptrs.data_ptr()),                                         \
+            static_cast<cutlass::float_e4m3_t*>(a_tensors.data_ptr()),                              \
+            static_cast<cutlass::float_e4m3_t*>(b_tensors.data_ptr()),                              \
+            static_cast<C_TYPE*>(out_tensors.data_ptr()),                                           \
+            static_cast<float*>(a_scales.data_ptr()),                                               \
+            static_cast<float*>(b_scales.data_ptr()),                                               \
+            reinterpret_cast<LayoutSFA*>(layout_sfa.data_ptr()),                                    \
+            reinterpret_cast<LayoutSFB*>(layout_sfb.data_ptr()),                                    \
+            static_cast<int*>(problem_sizes.data_ptr()),                                            \
+            static_cast<int*>(problem_sizes_transpose.data_ptr()),                                  \
+            transpose);                                                                             \
+  }
+
+namespace {
+template <typename LayoutSFA, typename LayoutSFB, typename ScaleConfig>
+void run_get_group_gemm_starts(
+    torch::Tensor const& expert_offsets,
+    torch::Tensor& a_ptrs,
+    torch::Tensor& b_ptrs,
+    torch::Tensor& out_ptrs,
+    torch::Tensor& a_scales_ptrs,
+    torch::Tensor& b_scales_ptrs,
+    torch::Tensor const& a_tensors,
+    torch::Tensor const& b_tensors,
+    torch::Tensor& out_tensors,
+    torch::Tensor const& a_scales,
+    torch::Tensor const& b_scales,
+    torch::Tensor const& layout_sfa,
+    torch::Tensor const& layout_sfb,
+    torch::Tensor const& problem_sizes,
+    torch::Tensor& problem_sizes_transpose,
+    bool transpose = false) {
+  TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn);
+  TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn);
+  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
+  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
+  TORCH_CHECK(out_tensors.size(1) % 128 == 0 or out_tensors.size(0) % 128 == 0);
+  TORCH_CHECK(a_tensors.size(1) % 128 == 0 or a_tensors.size(0) % 128 == 0);
+
+  int num_experts = (int)expert_offsets.size(0);
+  auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
+
+  if (false) {
+  }
+  __CALL_GET_STARTS_KERNEL(torch::kBFloat16, cutlass::bfloat16_t, LayoutSFA, LayoutSFB, ScaleConfig)
+  __CALL_GET_STARTS_KERNEL(torch::kFloat16, half, LayoutSFA, LayoutSFB, ScaleConfig)
+  else {
+    TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
+  }
+}
+}  // namespace
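`get_group_gemm_starts` computes one set of base-pointer offsets per expert from `expert_offsets` and `problem_sizes`. The same indexing written out in Python, as a sketch for clarity (the `// 128` factors reflect the 1x128 A-scale and 128x128 B-scale granularity; names are illustrative):

```python
# Per-expert offsets computed by get_group_gemm_starts, in Python form.
# A is [sum(m_i), k] row-major, B is [num_experts, n, k]; off is the
# running row offset expert_offsets[expert_id].
def expert_pointer_offsets(expert_id, off, m, n, k, transpose=False):
    if not transpose:
        return {
            "a": off * k,
            "b": expert_id * k * n,
            "out": off * n,
            "a_scale": off * (k // 128),                       # one scale per 1x128 block
            "b_scale": expert_id * (k // 128) * (n // 128),    # one scale per 128x128 block
        }
    # Transposed path: A and B trade places, so the offsets swap too.
    return {
        "a": expert_id * k * n,
        "b": off * k,
        "out": off * n,
        "a_scale": expert_id * (k // 128) * (n // 128),
        "b_scale": off * (k // 128),
    }
```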
"cutlass/gemm/group_array_problem_shape.hpp" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler_params.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass_moe_helper.cu" +#include "utils.h" + +using namespace cute; + +using ProblemShape = cutlass::gemm::GroupProblemShape>; +template +void launch_sm100_fp8_blockwise_scaled_group_mm( + torch::Tensor& out_ptrs, + const torch::Tensor& a_ptrs, + const torch::Tensor& b_ptrs, + const torch::Tensor& a_scales_ptrs, + const torch::Tensor& b_scales_ptrs, + const torch::Tensor& stride_a, + const torch::Tensor& stride_b, + const torch::Tensor& stride_c, + const torch::Tensor& layout_sfa, + const torch::Tensor& layout_sfb, + const torch::Tensor& problem_sizes, + const torch::Tensor& expert_offsets, + const torch::Tensor& workspace) { + using ProblemShape = cutlass::gemm::GroupProblemShape>; + using ElementA = cutlass::float_e4m3_t; + using ElementB = cutlass::float_e4m3_t; + using ElementC = OutType; + using ElementD = ElementC; + using ElementAccumulator = float; + // Layout definitions + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = LayoutD; + + // Alignment constraints + static constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; + static constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; + static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; + + // Architecture definitions + using ArchTag = cutlass::arch::Sm100; + using OperatorClass = cutlass::arch::OpClassTensorOp; + // For fp8 block scale. 
+ // using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig; using LayoutSFA = + // decltype(ScaleConfig::deduce_layoutSFA()); using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OperatorClass, + typename ScheduleConfig::MmaTileShape, + typename ScheduleConfig::ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, + ElementAccumulator, + void, + LayoutC*, + AlignmentC, + ElementD, + LayoutC*, + AlignmentC, + typename ScheduleConfig::EpilogueSchedule>::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OperatorClass, + ElementA, + cute::tuple, + AlignmentA, + ElementB, + cute::tuple, + AlignmentB, + ElementAccumulator, + typename ScheduleConfig::MmaTileShape, + typename ScheduleConfig::ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + typename ScheduleConfig::KernelSchedule>::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape; + using StrideA = typename Gemm::GemmKernel::InternalStrideA; + using StrideB = typename Gemm::GemmKernel::InternalStrideB; + using StrideC = typename Gemm::GemmKernel::InternalStrideC; + using StrideD = typename Gemm::GemmKernel::InternalStrideD; + + int num_experts = (int)expert_offsets.size(0); + // Create an instance of the GEMM + Gemm gemm_op; + + typename GemmKernel::MainloopArguments mainloop_args{ + static_cast(a_ptrs.data_ptr()), + static_cast(stride_a.data_ptr()), + static_cast(b_ptrs.data_ptr()), + static_cast(stride_b.data_ptr()), + static_cast(a_scales_ptrs.data_ptr()), + reinterpret_cast(layout_sfa.data_ptr()), + static_cast(b_scales_ptrs.data_ptr()), + reinterpret_cast(layout_sfb.data_ptr())}; + + cutlass::KernelHardwareInfo hw_info; + + hw_info.device_id = 0; + hw_info.sm_count = 1; + // Currently, we are only able to do broadcast on either all or none a_scales + // and on either all or none b_scales + typename GemmKernel::EpilogueArguments epilogue_args{ + {}, + nullptr, + static_cast(stride_c.data_ptr()), + static_cast(out_ptrs.data_ptr()), + static_cast(stride_c.data_ptr())}; + + // Initialize problem_sizes_as_shapes correctly + UnderlyingProblemShape* problem_sizes_as_shapes = static_cast(problem_sizes.data_ptr()); + // Use prob_shape in the GEMM arguments + typename GemmKernel::Arguments args{ + cutlass::gemm::GemmUniversalMode::kGrouped, + {num_experts, problem_sizes_as_shapes, nullptr}, + mainloop_args, + epilogue_args, + hw_info}; + + auto can_implement_status = gemm_op.can_implement(args); + TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess, "Failed to implement GEMM"); + + // Run the GEMM + auto status = gemm_op.initialize(args, workspace.data_ptr()); + + TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM"); + + status = gemm_op.run(); + TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM"); +} + +template +void sm100_fp8_blockwise_group_mm_dispatch_shape( + torch::Tensor& output, + const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& scales_a, + const torch::Tensor& scales_b, + const torch::Tensor& stride_a, + const torch::Tensor& stride_b, + const torch::Tensor& stride_c, + const torch::Tensor& layout_sfa, + const torch::Tensor& 
+
+template <typename OutType>
+void sm100_fp8_blockwise_group_mm_dispatch_shape(
+    torch::Tensor& output,
+    const torch::Tensor& a,
+    const torch::Tensor& b,
+    const torch::Tensor& scales_a,
+    const torch::Tensor& scales_b,
+    const torch::Tensor& stride_a,
+    const torch::Tensor& stride_b,
+    const torch::Tensor& stride_c,
+    const torch::Tensor& layout_sfa,
+    const torch::Tensor& layout_sfb,
+    const torch::Tensor& problem_sizes,
+    const torch::Tensor& expert_offsets) {
+  // Pick the configuration from the size of the first matrix in the group,
+  // assuming all matrices in the group have similar size characteristics.
+  struct MMALargeConfig {
+    using ElementA = cutlass::float_e4m3_t;
+    using MmaTileShape = Shape<_128, _128, _128>;
+    using ClusterShape = Shape<_1, _1, _1>;
+    using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedBlockwise1SmSm100;
+    using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm;
+    // 1x128 scale granularity for A, 128x128 for B.
+    using ScaleConfig =
+        cutlass::detail::Sm100BlockwiseScaleConfig<1, 128, 128, cute::UMMA::Major::K, cute::UMMA::Major::K>;
+    using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());  // Layout type for the SFA operand
+    using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());  // Layout type for the SFB operand
+  };
+
+  struct MMASmallConfig {
+    using ElementA = cutlass::float_e4m3_t;
+    using MmaTileShape = Shape<_128, _16, _128>;
+    using ClusterShape = Shape<_1, _1, _1>;
+    using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedBlockwise1SmSm100;
+    using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm;
+    // Scale granularities are swapped because A and B trade places on this path.
+    using ScaleConfig =
+        cutlass::detail::Sm100BlockwiseScaleConfig<128, 1, 128, cute::UMMA::Major::K, cute::UMMA::Major::K>;
+    using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
+    using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
+  };
+
+  int num_experts = (int)expert_offsets.size(0);
+  torch::TensorOptions options_int = torch::TensorOptions().dtype(torch::kInt64).device(a.device());
+  torch::Tensor a_ptrs = torch::empty(num_experts, options_int);
+  torch::Tensor b_ptrs = torch::empty(num_experts, options_int);
+  torch::Tensor out_ptrs = torch::empty(num_experts, options_int);
+  torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int);
+  torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int);
+  torch::Tensor problem_sizes_transpose = torch::empty(num_experts * 3, options_int);
+  torch::Tensor workspace = torch::empty(100, options_int);
+  torch::Tensor output_t = output.t();
+  torch::Tensor a_t = a.t();
+  torch::Tensor b_t = b.transpose(1, 2);
+  torch::Tensor scales_a_t = scales_a.t();
+  torch::Tensor scales_b_t = scales_b.transpose(1, 2);
+
+  if (a.size(0) <= 512) {
+    // Small batch: compute C^T = B^T * A^T with the roles of A and B swapped,
+    // so the token dimension maps onto the narrow N tile.
+    run_get_group_gemm_starts<
+        typename MMASmallConfig::LayoutSFA,
+        typename MMASmallConfig::LayoutSFB,
+        typename MMASmallConfig::ScaleConfig>(
+        expert_offsets,
+        a_ptrs,
+        b_ptrs,
+        out_ptrs,
+        a_scales_ptrs,
+        b_scales_ptrs,
+        b_t,
+        a_t,
+        output_t,
+        scales_b_t,
+        scales_a_t,
+        layout_sfa,
+        layout_sfb,
+        problem_sizes,
+        problem_sizes_transpose,
+        true);
+    launch_sm100_fp8_blockwise_scaled_group_mm<OutType, MMASmallConfig, cutlass::layout::ColumnMajor>(
+        out_ptrs,
+        a_ptrs,
+        b_ptrs,
+        a_scales_ptrs,
+        b_scales_ptrs,
+        stride_a,
+        stride_b,
+        stride_c,
+        layout_sfa,
+        layout_sfb,
+        problem_sizes_transpose,
+        expert_offsets,
+        workspace);
+    output = output_t.t();
+  } else {
+    run_get_group_gemm_starts<
+        typename MMALargeConfig::LayoutSFA,
+        typename MMALargeConfig::LayoutSFB,
+        typename MMALargeConfig::ScaleConfig>(
+        expert_offsets,
+        a_ptrs,
+        b_ptrs,
+        out_ptrs,
+        a_scales_ptrs,
+        b_scales_ptrs,
+        a,
+        b,
+        output,
+        scales_a,
+        scales_b,
+        layout_sfa,
+        layout_sfb,
+        problem_sizes,
+        problem_sizes_transpose);
+    launch_sm100_fp8_blockwise_scaled_group_mm<OutType, MMALargeConfig, cutlass::layout::RowMajor>(
+        out_ptrs,
+        a_ptrs,
+        b_ptrs,
+        a_scales_ptrs,
+        b_scales_ptrs,
+        stride_a,
+        stride_b,
+        stride_c,
+        layout_sfa,
+        layout_sfb,
+        problem_sizes,
+        expert_offsets,
+        workspace);
+  }
+}
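The small-batch branch above leans on the identity C = (BᵀAᵀ)ᵀ: swapping the operand roles puts the token count on the GEMM's N axis, where `MMASmallConfig` uses a narrow `_16` tile. A quick illustrative check of the identity (sketch only, not part of the patch):

```python
import torch

# C = A @ B is recovered from the transposed product (B^T @ A^T)^T, which is
# what the a.size(0) <= 512 path computes with A/B (and their scales) swapped.
A = torch.randn(16, 64)
B = torch.randn(64, 32)
C = A @ B
C_via_transpose = (B.t() @ A.t()).t()
torch.testing.assert_close(C, C_via_transpose, rtol=1e-5, atol=1e-5)
```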
+
+/**
+ * @brief Performs blockwise grouped matrix multiplication on FP8-quantized
+ *        inputs, with per-block scaling.
+ *
+ * This function dispatches to hardware-specific implementations (e.g., SM100 FP8)
+ * to compute:
+ *     C_i = (scale_a[i] * A_i) * (scale_b[i] * B_i)
+ * for each expert group `i`, using `problem_sizes` and `expert_offsets`
+ * to describe the individual matrix dimensions and their offsets.
+ *
+ * Input tensors A and B hold FP8-quantized values; they are dequantized with
+ * the per-block scales as part of the multiplication. The output tensor is
+ * written in bfloat16 or half precision.
+ *
+ * @param output Output tensor (must be bfloat16 or half).
+ * @param a Input tensor A (must be kFloat8_e4m3fn).
+ * @param b Input tensor B (must be kFloat8_e4m3fn).
+ * @param scales_a Scaling factors for tensor A, float32, one per 1x128 block.
+ * @param scales_b Scaling factors for tensor B, float32, one per 128x128 block.
+ * @param stride_a Stride information for tensor A (int64), one entry per expert.
+ * @param stride_b Stride information for tensor B (int64), one entry per expert.
+ * @param stride_c Stride information for output tensor C (int64), one entry per expert.
+ * @param layout_sfa Scale-factor layout descriptor for A (int32), filled in on device.
+ * @param layout_sfb Scale-factor layout descriptor for B (int32), filled in on device.
+ * @param problem_sizes 2D int32 tensor of shape (num_experts, 3), specifying (M, N, K)
+ *        for each grouped matrix multiplication problem.
+ * @param expert_offsets 1D int32 tensor of size (num_experts), used to index into
+ *        the grouped input tensors for dispatch.
+ * @note Performance optimization: if the batch size (a.size(0)) is at most 512,
+ *       the implementation internally transposes the input matrices to align with
+ *       the optimal memory access pattern for better GPU efficiency. This
+ *       transformation is done within the kernel.
+ */
+void fp8_blockwise_scaled_grouped_mm(
+    torch::Tensor& output,
+    const torch::Tensor& a,
+    const torch::Tensor& b,
+    const torch::Tensor& scales_a,
+    const torch::Tensor& scales_b,
+    const torch::Tensor& stride_a,
+    const torch::Tensor& stride_b,
+    const torch::Tensor& stride_c,
+    const torch::Tensor& layout_sfa,
+    const torch::Tensor& layout_sfb,
+    const torch::Tensor& problem_sizes,
+    const torch::Tensor& expert_offsets) {
+  TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor");
+  TORCH_CHECK(problem_sizes.size(1) == 3, "problem_sizes must have shape (num_experts, 3)");
+  TORCH_CHECK(
+      problem_sizes.size(0) == expert_offsets.size(0), "Number of experts in problem_sizes must match expert_offsets");
+  TORCH_CHECK(problem_sizes.dtype() == torch::kInt32, "problem_sizes must be int32");
+  TORCH_CHECK(a.scalar_type() == torch::kFloat8_e4m3fn, "a must be kFloat8_e4m3fn");
+  TORCH_CHECK(b.scalar_type() == torch::kFloat8_e4m3fn, "b must be kFloat8_e4m3fn");
+  TORCH_CHECK(
+      output.scalar_type() == torch::kBFloat16 || output.scalar_type() == torch::kHalf,
+      "output must be bfloat16 or half");
+  TORCH_CHECK(scales_a.scalar_type() == torch::kFloat32, "scales_a must be float32");
+  TORCH_CHECK(scales_b.scalar_type() == torch::kFloat32, "scales_b must be float32");
+  TORCH_CHECK(stride_a.scalar_type() == torch::kInt64, "stride_a must be int64");
+  TORCH_CHECK(stride_b.scalar_type() == torch::kInt64, "stride_b must be int64");
+  TORCH_CHECK(stride_c.scalar_type() == torch::kInt64, "stride_c must be int64");
+  TORCH_CHECK(layout_sfa.scalar_type() == torch::kInt32, "layout_sfa must be int32");
+  TORCH_CHECK(layout_sfb.scalar_type() == torch::kInt32, "layout_sfb must be int32");
+  TORCH_CHECK(expert_offsets.scalar_type() == torch::kInt32, "expert_offsets must be int32");
+
+  bool can_implement = false;
+  auto sm_version = getSMVersion();
+
+#if defined(CUTLASS_ARCH_MMA_SM100A_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
+#if defined CUDA_VERSION && CUDA_VERSION >= 12080
+  if (sm_version == 100) {
+    if (output.scalar_type() == torch::kBFloat16) {
+      sm100_fp8_blockwise_group_mm_dispatch_shape<cutlass::bfloat16_t>(
+          output,
+          a,
+          b,
+          scales_a,
+          scales_b,
+          stride_a,
+          stride_b,
+          stride_c,
+          layout_sfa,
+          layout_sfb,
+          problem_sizes,
+          expert_offsets);
+    } else {
+      sm100_fp8_blockwise_group_mm_dispatch_shape<cutlass::half_t>(
+          output,
+          a,
+          b,
+          scales_a,
+          scales_b,
+          stride_a,
+          stride_b,
+          stride_c,
+          layout_sfa,
+          layout_sfb,
+          problem_sizes,
+          expert_offsets);
+    }
+    can_implement = true;
+  }
+#endif
+#endif
+  TORCH_CHECK_NOT_IMPLEMENTED(
+      can_implement, "fp8_blockwise_scaled_grouped_mm is not implemented for the current compute capability: ", sm_version);
+}
diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h
old mode 100644
new mode 100755
index 10df9d1c7..3c906f587
--- a/sgl-kernel/include/sgl_kernel_ops.h
+++ b/sgl-kernel/include/sgl_kernel_ops.h
@@ -209,6 +209,20 @@ std::vector<at::Tensor> moe_fused_gate(
     int64_t n_share_experts_fusion,
     double routed_scaling_factor);

+void fp8_blockwise_scaled_grouped_mm(
+    torch::Tensor& output,
+    const torch::Tensor& a,
+    const torch::Tensor& b,
+    const torch::Tensor& scales_a,
+    const torch::Tensor& scales_b,
+    const torch::Tensor& stride_a,
+    const torch::Tensor& stride_b,
+    const torch::Tensor& stride_c,
+    const torch::Tensor& layout_sfa,
+    const torch::Tensor& layout_sfb,
+    const torch::Tensor& problem_sizes,
+    const torch::Tensor& expert_offsets);
+
 /*
  * From csrc/speculative
  */
diff --git a/sgl-kernel/python/sgl_kernel/__init__.py b/sgl-kernel/python/sgl_kernel/__init__.py
old mode 100644
new mode 100755
index a6338ee5a..e6ba19e0f
--- a/sgl-kernel/python/sgl_kernel/__init__.py
+++ b/sgl-kernel/python/sgl_kernel/__init__.py
@@ -41,7 +41,12 @@ from sgl_kernel.gemm import (
     sgl_per_token_group_quant_int8,
     sgl_per_token_quant_fp8,
 )
-from sgl_kernel.moe import moe_align_block_size, moe_fused_gate, topk_softmax
+from sgl_kernel.moe import (
+    fp8_blockwise_scaled_grouped_mm,
+    moe_align_block_size,
+    moe_fused_gate,
+    topk_softmax,
+)
 from sgl_kernel.sampling import (
     min_p_sampling_from_probs,
     top_k_renorm_prob,
diff --git a/sgl-kernel/python/sgl_kernel/moe.py b/sgl-kernel/python/sgl_kernel/moe.py
old mode 100644
new mode 100755
index afabc44f9..a5e0b3668
--- a/sgl-kernel/python/sgl_kernel/moe.py
+++ b/sgl-kernel/python/sgl_kernel/moe.py
@@ -60,3 +60,33 @@ def moe_fused_gate(
         n_share_experts_fusion,
         routed_scaling_factor,
     )
+
+
+def fp8_blockwise_scaled_grouped_mm(
+    output,
+    a,
+    b,
+    scales_a,
+    scales_b,
+    stride_a,
+    stride_b,
+    stride_c,
+    layout_sfa,
+    layout_sfb,
+    problem_sizes,
+    expert_offsets,
+):
+    torch.ops.sgl_kernel.fp8_blockwise_scaled_grouped_mm.default(
+        output,
+        a,
+        b,
+        scales_a,
+        scales_b,
+        stride_a,
+        stride_b,
+        stride_c,
+        layout_sfa,
+        layout_sfb,
+        problem_sizes,
+        expert_offsets,
+    )
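The Python wrapper forwards everything verbatim, so callers must satisfy the dtype contract enforced by the `TORCH_CHECK`s in the C++ entry point. A hedged sketch of the metadata a caller prepares (shapes are illustrative; the test below shows a complete, verified setup):

```python
import torch

num_experts, m, n, k = 4, 128, 512, 256
dev = "cuda"

# Per-expert (M, N, K) triples and row offsets into the stacked A/C tensors.
problem_sizes = torch.tensor([[m, n, k]] * num_experts, dtype=torch.int32, device=dev)
expert_offsets = torch.arange(0, num_experts * m, m, dtype=torch.int32, device=dev)

# Leading-dimension strides, one entry per expert (must be int64).
a_strides = torch.full((num_experts,), k, dtype=torch.int64, device=dev)
c_strides = torch.full((num_experts,), n, dtype=torch.int64, device=dev)

# Scale-factor layout descriptors (int32); the helper kernel fills these in.
layout_sfa = torch.zeros((num_experts, 5), dtype=torch.int32, device=dev)
layout_sfb = torch.zeros((num_experts, 5), dtype=torch.int32, device=dev)
```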
diff --git a/sgl-kernel/tests/test_fp8_blockwise_moe.py b/sgl-kernel/tests/test_fp8_blockwise_moe.py
new file mode 100755
index 000000000..27ef9c021
--- /dev/null
+++ b/sgl-kernel/tests/test_fp8_blockwise_moe.py
@@ -0,0 +1,148 @@
+import random
+
+import pytest
+import torch
+from sgl_kernel import fp8_blockwise_scaled_grouped_mm
+
+
+def cdiv(a: int, b: int) -> int:
+    return -(a // -b)
+
+
+def scale_shape(shape, group_shape):
+    return tuple(cdiv(shape[i], group_shape[i]) for i in range(len(group_shape)))
+
+
+def to_fp8(tensor: torch.Tensor) -> torch.Tensor:
+    finfo = torch.finfo(torch.float8_e4m3fn)
+    return torch.round(tensor.clamp(min=finfo.min, max=finfo.max)).to(
+        dtype=torch.float8_e4m3fn
+    )
+
+
+def baseline_scaled_mm(
+    a: torch.Tensor,
+    b: torch.Tensor,
+    scale_a: torch.Tensor,
+    scale_b: torch.Tensor,
+    out_dtype: type[torch.dtype],
+) -> torch.Tensor:
+
+    def group_broadcast(t, shape):
+        for i, s in enumerate(shape):
+            if t.shape[i] != s and t.shape[i] != 1:
+                assert s % t.shape[i] == 0
+                t = (
+                    t.unsqueeze(i + 1)
+                    .expand(*t.shape[: i + 1], s // t.shape[i], *t.shape[i + 1 :])
+                    .flatten(i, i + 1)
+                )
+        return t
+
+    scale_a = group_broadcast(scale_a, a.shape)
+    scale_b = group_broadcast(scale_b, b.shape)
+
+    return torch.mm(
+        (scale_a * a.to(dtype=torch.float32)), (scale_b * b.to(dtype=torch.float32))
+    ).to(out_dtype)
+
+
+@pytest.mark.parametrize("num_experts", [8, 16])
+@pytest.mark.parametrize("out_dtype", [torch.half, torch.bfloat16])
+def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype):
+    device = "cuda"
+    alignment = 16
+    n_g = alignment * random.randint(1, 5) * 128
+    k_g = alignment * random.randint(1, 5) * 128
+
+    scale_a_group_shape = (1, 128)
+    scale_b_group_shape = (128, 128)
+
+    expert_offsets = torch.zeros((num_experts + 1), device=device, dtype=torch.int32)
+    problem_sizes = torch.zeros((num_experts, 3), device=device, dtype=torch.int32)
+    layout_sfa = torch.zeros((num_experts, 5), device=device, dtype=torch.int32)
+    layout_sfb = torch.zeros((num_experts, 5), device=device, dtype=torch.int32)
+
+    a_tensors = []
+    b_tensors = []
+    a_scales_tensors = []
+    b_scales_tensors = []
+    baseline_tensors = []
+
+    for g in range(num_experts):
+        m_g = alignment * random.randint(1, 64)
+        expert_offsets[g + 1] = expert_offsets[g] + m_g
+        problem_sizes[g][:] = torch.tensor([m_g, n_g, k_g], device=device)
+
+        a_g = to_fp8(torch.randn((m_g, k_g), device=device))
+        b_g = to_fp8(torch.randn((n_g, k_g), device=device).t())
+        a_tensors.append(a_g)
+        b_tensors.append(b_g)
+
+        scale_a_shape = scale_shape(a_g.shape, scale_a_group_shape)
+        scale_b_shape = scale_shape(b_g.shape, scale_b_group_shape)
+
+        a_scales_tensors.append(torch.randn(scale_a_shape, device=device) * 0.001)
+        b_scales_tensors.append(torch.randn(scale_b_shape, device=device) * 0.001)
+
+        baseline = baseline_scaled_mm(
+            a_g, b_g, a_scales_tensors[-1], b_scales_tensors[-1], out_dtype
+        )
+        baseline_tensors.append(baseline)
+
+    a_stack = torch.empty(
+        (expert_offsets[-1], k_g), device=device, dtype=torch.float8_e4m3fn
+    )
+    b_stack = torch.empty(
+        (num_experts, n_g, k_g), device=device, dtype=torch.float8_e4m3fn
+    )
+
+    for g in range(num_experts):
+        a_stack[expert_offsets[g] : expert_offsets[g + 1]] = a_tensors[g]
+        b_stack[g] = b_tensors[g].t()
+    b_stack = b_stack.transpose(1, 2)
+
+    a_scale_stack = torch.empty(
+        (expert_offsets[-1], k_g // 128), device=device, dtype=torch.float32
+    )
+    b_scale_stack = torch.empty(
+        (num_experts, n_g // 128, k_g // 128), device=device, dtype=torch.float32
+    )
+
+    for g in range(num_experts):
+        a_scale_stack[expert_offsets[g] : expert_offsets[g + 1]] = a_scales_tensors[g]
+        b_scale_stack[g] = b_scales_tensors[g].t()
+    b_scale_stack = b_scale_stack.transpose(1, 2)
+
+    c_out = torch.empty((expert_offsets[-1], n_g), device=device, dtype=out_dtype)
+    a_strides = torch.full(
+        (num_experts,), a_stack.stride(0), device=device, dtype=torch.int64
+    )
+    c_strides = torch.full(
+        (num_experts,), c_out.stride(0), device=device, dtype=torch.int64
+    )
+
+    fp8_blockwise_scaled_grouped_mm(
+        c_out,
+        a_stack,
+        b_stack,
+        a_scale_stack,
+        b_scale_stack,
+        a_strides,
+        a_strides,
+        c_strides,
+        layout_sfa,
+        layout_sfb,
+        problem_sizes,
+        expert_offsets[:-1],
+    )
+
+    for g in range(num_experts):
+        baseline = baseline_tensors[g]
+        actual = c_out[expert_offsets[g] : expert_offsets[g + 1]]
+        torch.testing.assert_close(actual, baseline, rtol=1e-2, atol=5e-4)
+    print(f"num_experts={num_experts}, out_dtype={out_dtype}: OK")
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])
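One piece the patch leaves to the caller is deriving `expert_offsets` and `problem_sizes` from router output. A hypothetical sketch (not part of this PR), assuming tokens are already permuted so each expert's rows are contiguous:

```python
import torch

num_experts, n, k, num_tokens = 8, 512, 256, 1024
expert_ids = torch.randint(0, num_experts, (num_tokens,), device="cuda")

# Tokens per expert -> cumulative row offsets; the kernel consumes the first
# num_experts entries (the test above passes expert_offsets[:-1]).
counts = torch.bincount(expert_ids, minlength=num_experts).to(torch.int32)
expert_offsets = torch.zeros(num_experts + 1, dtype=torch.int32, device="cuda")
expert_offsets[1:] = torch.cumsum(counts, dim=0)

# (M, N, K) per expert: M varies with routing; N and K come from the weights.
problem_sizes = torch.stack(
    [counts, torch.full_like(counts, n), torch.full_like(counts, k)], dim=1
)
```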