From da3890e82a97e9b1f970133c0fef6341a8bbf791 Mon Sep 17 00:00:00 2001 From: SijiaYang Date: Sat, 5 Jul 2025 11:50:12 +0800 Subject: [PATCH] [1/n]: add cutlass W4A8 moe kernel for hopper architecture (#7772) Signed-off-by: yangsijia.614 Co-authored-by: yicwang --- sgl-kernel/CMakeLists.txt | 3 + sgl-kernel/csrc/common_extension.cc | 19 + .../detail/collective/mixed_input_utils.hpp | 482 ++++++ .../sm90_gmma_builder_mixed_input.inl | 278 +++ .../collective_builder_mixed_input.hpp | 52 + .../collective_mma_array_mixed_input.hpp | 53 + ...a_gmma_rs_warpspecialized_mixed_input_.hpp | 1535 +++++++++++++++++ .../moe/cutlass_moe/w4a8/scaled_mm_entry.cu | 91 + .../w4a8/w4a8_get_group_starts.cuh | 92 + .../cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cu | 240 +++ .../cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cuh | 276 +++ .../moe/cutlass_moe/w4a8/w4a8_moe_data.cu | 79 + sgl-kernel/include/sgl_kernel_ops.h | 29 + sgl-kernel/python/sgl_kernel/__init__.py | 1 + sgl-kernel/python/sgl_kernel/cutlass_moe.py | 112 ++ sgl-kernel/tests/test_cutlass_w4a8_moe_mm.py | 260 +++ 16 files changed, 3602 insertions(+) create mode 100644 sgl-kernel/csrc/cutlass_extensions/detail/collective/mixed_input_utils.hpp create mode 100644 sgl-kernel/csrc/cutlass_extensions/gemm/collective/builders/sm90_gmma_builder_mixed_input.inl create mode 100644 sgl-kernel/csrc/cutlass_extensions/gemm/collective/collective_builder_mixed_input.hpp create mode 100644 sgl-kernel/csrc/cutlass_extensions/gemm/collective/collective_mma_array_mixed_input.hpp create mode 100644 sgl-kernel/csrc/cutlass_extensions/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input_.hpp create mode 100644 sgl-kernel/csrc/moe/cutlass_moe/w4a8/scaled_mm_entry.cu create mode 100644 sgl-kernel/csrc/moe/cutlass_moe/w4a8/w4a8_get_group_starts.cuh create mode 100644 sgl-kernel/csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cu create mode 100644 sgl-kernel/csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cuh create mode 100644 sgl-kernel/csrc/moe/cutlass_moe/w4a8/w4a8_moe_data.cu create mode 100644 sgl-kernel/python/sgl_kernel/cutlass_moe.py create mode 100644 sgl-kernel/tests/test_cutlass_w4a8_moe_mm.py diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt index 1bc289519..57c444e76 100644 --- a/sgl-kernel/CMakeLists.txt +++ b/sgl-kernel/CMakeLists.txt @@ -249,6 +249,9 @@ set(SOURCES "csrc/speculative/speculative_sampling.cu" "csrc/grammar/apply_token_bitmask_inplace_cuda.cu" "csrc/kvcacheio/transfer.cu" + "csrc/moe/cutlass_moe/w4a8/scaled_mm_entry.cu" + "csrc/moe/cutlass_moe/w4a8/w4a8_moe_data.cu" + "csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cu" "csrc/common_extension.cc" "csrc/moe/marlin_moe_wna16/ops.cu" "csrc/moe/marlin_moe_wna16/gptq_marlin_repack.cu" diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc index b6a22152a..bdec3cf1a 100644 --- a/sgl-kernel/csrc/common_extension.cc +++ b/sgl-kernel/csrc/common_extension.cc @@ -277,6 +277,25 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { "int num_layers) -> ()"); m.impl("transfer_kv_all_layer_mla_direct", torch::kCUDA, &transfer_kv_all_layer_mla_direct); + /* + * From csrc/moe/cutlass_moe/w4a8 + */ + m.def( + "get_cutlass_w4a8_moe_mm_data(Tensor topk_ids, Tensor! expert_offsets, " + " Tensor! problem_sizes1, Tensor! problem_sizes2, " + " Tensor! input_permutation, " + " Tensor! output_permutation, int num_experts, " + " int n, int k) -> ()"); + m.impl("get_cutlass_w4a8_moe_mm_data", torch::kCUDA, &get_cutlass_w4a8_moe_mm_data); + + m.def( + "cutlass_w4a8_moe_mm(Tensor! 
d, Tensor a, Tensor b, " + " Tensor a_scales, Tensor b_scales, Tensor expert_offsets, " + " Tensor problem_sizes, Tensor a_strides, " + " Tensor b_strides, Tensor d_strides, Tensor s_strides," + " int chunk_size, int topk) -> ()"); + m.impl("cutlass_w4a8_moe_mm", torch::kCUDA, &cutlass_w4a8_moe_mm); + /* * From FlashInfer */ diff --git a/sgl-kernel/csrc/cutlass_extensions/detail/collective/mixed_input_utils.hpp b/sgl-kernel/csrc/cutlass_extensions/detail/collective/mixed_input_utils.hpp new file mode 100644 index 000000000..fcba4d40c --- /dev/null +++ b/sgl-kernel/csrc/cutlass_extensions/detail/collective/mixed_input_utils.hpp @@ -0,0 +1,482 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "cute/arch/copy_sm90.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" +#include "cute/util/type_traits.hpp" +#include "cutlass/cutlass.h" +#include "cutlass/numeric_conversion.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective::detail { + +template +struct MixedGroupedGemmInputUtils { + private: + using KernelSchedule = typename Collective::KernelSchedule; + using ConversionMode = typename Collective::ConversionMode; + using SmemLayoutA = typename Collective::SmemLayoutA; + using SmemLayoutB = typename Collective::SmemLayoutB; + using SmemLayoutScale = typename Collective::SmemLayoutScale; + using SwappedElementA = typename Collective::SwappedElementA; + using SwappedElementB = typename Collective::SwappedElementB; + using RealSwappedElementA = typename Collective::RealSwappedElementA; + using RealSwappedElementB = typename Collective::RealSwappedElementB; + using ElementScale = typename Collective::ElementScale; + using ElementZero = typename Collective::ElementZero; + using SmemCopyAtomScale = typename Collective::SmemCopyAtomScale; + static constexpr auto KernelConversionMode = Collective::KernelConversionMode; + static constexpr auto ModeHasScales = Collective::ModeHasScales; + static constexpr auto UseScaleLookupTable = Collective::UseScaleLookupTable; + + public: + static constexpr auto elements_per_smem_scale() { + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + return 0; + } else if constexpr (ModeHasScales) { + return cute::cosize_v; + } else { + static_assert(cutlass::detail::dependent_false, "Type not handled in scale smem allocation."); + } + } + + static constexpr auto elements_per_smem_zero() { + if constexpr ( + KernelConversionMode == ConversionMode::DirectConvert || + KernelConversionMode == ConversionMode::ConvertAndScale) { + return 0; + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + return cute::cosize_v; + } else { + static_assert(cutlass::detail::dependent_false, "Type not handled in scale smem allocation."); + } + } + + // These methods use some the public members of the class. 
For that reason, we define them after the public section. + static constexpr uint32_t compute_tma_transaction_bytes_mk() { + return cutlass::bits_to_bytes( + size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(cute::sizeof_bits_v)); + } + + static constexpr uint32_t compute_tma_transaction_bytes_nk() { + return cutlass::bits_to_bytes( + size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(cute::sizeof_bits_v)); + } + + static constexpr uint32_t compute_tma_transaction_bytes_extra() { + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + return 0; + } else if constexpr (ModeHasScales) { + constexpr uint32_t scale_tx_bytes = cutlass::bits_to_bytes( + size<0>(SmemLayoutScale{}) * size<1>(SmemLayoutScale{}) * + static_cast(cute::sizeof_bits_v)); + static_assert(scale_tx_bytes % 128 == 0, "Each scale stage must be 128B aligned."); // required by TMA + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return scale_tx_bytes; + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + // Scale and zero share smem layout + constexpr uint32_t zero_tx_bytes = cutlass::bits_to_bytes( + size<0>(SmemLayoutScale{}) * size<1>(SmemLayoutScale{}) * + static_cast(cute::sizeof_bits_v)); + static_assert(zero_tx_bytes % 128 == 0, "Each zero stage must be 128B aligned."); // required by TMA + return scale_tx_bytes + zero_tx_bytes; + } else { + static_assert( + cutlass::detail::dependent_false, "Type not handled in tma transaction bytes computation."); + } + } else { + static_assert( + cutlass::detail::dependent_false, "Type not handled in tma transaction bytes computation."); + } + } + + /// Utilities to copy A and extra inputs from smem to RF + template + CUTLASS_DEVICE static void copy_tensors_MK( + SmemTiledCopyA const& smem_tiled_copy_A, + TensorASmemView const& tCsA, + TensorACopyView& tCrA_copy_view, + cute::tuple const& partitioned_mma_extra_info, + cute::tuple const& tiled_copy_and_views, + int k_block, + int read_stage) { + copy(smem_tiled_copy_A, tCsA(_, _, k_block, read_stage), tCrA_copy_view(_, _, k_block)); + + if (k_block == 0) { + // We are starting a new k-tile so copy the scale + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + // nothing to do + } else if constexpr (ModeHasScales) { + auto smem_tiled_copy_S = cute::get<0>(tiled_copy_and_views); + auto tCrS_copy_view = cute::get<1>(tiled_copy_and_views); + auto tCsS = cute::get<0>(partitioned_mma_extra_info); + copy(smem_tiled_copy_S, tCsS(_, _, k_block, read_stage), tCrS_copy_view(_, _, k_block)); + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + // Nothing extra to do + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + auto tCsZ = cute::get<2>(partitioned_mma_extra_info); + auto tCrZ_copy_view = cute::get<2>(tiled_copy_and_views); + copy(smem_tiled_copy_S, tCsZ(_, _, k_block, read_stage), tCrZ_copy_view(_, _, k_block)); + } else { + static_assert( + cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } + } else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } + } + } + + // The core converter uses a lookup table to converts i4 -> 8 bit value. 
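  // Illustrative sketch (hypothetical helper, not called by the kernel): the
  // same lookup-table idea in plain scalar form, where each packed signed
  // 4-bit nibble indexes a 16-entry table of precomputed 8-bit outputs. The
  // device path below instead folds the per-group scale into two 8-entry byte
  // tables (negative / positive magnitudes) and uses prmt/lop3 byte permutes
  // to convert the eight nibbles of a 32-bit register.
  CUTLASS_HOST_DEVICE
  static void lookup_table_convert_scalar_ref(uint32_t packed, int8_t const (&lut)[16], int8_t (&out)[8]) {
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < 8; ++i) {
      out[i] = lut[(packed >> (4 * i)) & 0xF];  // nibble i selects a table entry
    }
  }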
+ template < + class EngineIn, + class LayoutIn, + class EngineOut, + class LayoutOut, + class EngineScale, + class LayoutScale> + CUTLASS_DEVICE static void lookup_table_convert( // Accept mutable temporaries + Tensor const& src, + Tensor&& dst, + Tensor const& scales_neg, + Tensor const& scales_pos) { + lookup_table_convert(src, dst, scales_neg, scales_pos); + } + template + CUTLASS_DEVICE static void lookup_table_convert( + Tensor const& src, + Tensor& dst, + Tensor const& scales_neg, + Tensor const& scales_pos) { + constexpr int N = cute::cosize(LayoutIn{}); + static_assert(N == 4 || N == 8); + static_assert(cosize(LayoutScale{}) <= N / 4, "at least 4 consecutive weights must share the same scale."); + using SrcArray = cutlass::Array; + using DstArray = cutlass::Array; + using RegArray = cutlass::AlignedArray; + + // View the input as reg + auto&& src_reg = cute::recast(src)(0); + auto&& r = cute::recast(dst)(0); + + // Determines if to get from the signed or unsigned candidates + static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; + uint32_t sign; // ((reg & 0x88888888) | 0x64206420) >> 1 + asm volatile( + "{\n" + " lop3.b32 %0, %1, %2, %3, %4;\n" + "}\n" + : "=r"(sign) + : "r"(src_reg), "n"(0x88888888), "n"(0x64206420), "n"(immLut)); + sign = sign >> 1; + + // Ignore sign bit when indexing into LUT + uint32_t lut_idx = src_reg & 0x77777777; + Tensor scales_neg_ = cute::filter(scales_neg); + Tensor scales_pos_ = cute::filter(scales_pos); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 4; ++i, lut_idx >>= 16, sign >>= 16) { + auto&& scale_neg_ = reinterpret_cast const&>(scales_neg_(i)); + auto&& scale_pos_ = reinterpret_cast const&>(scales_pos_(i)); + asm volatile( + "{\n" + " .reg .b32 pos, neg ;\n" + " prmt .b32 neg, %3, %4, %1 ;\n" + " prmt .b32 pos, %5, %6, %1 ;\n" + " prmt .b32 %0, pos, neg, %2 ;\n" + "}\n" + : "=r"(r[i]) + : "r"(lut_idx), "r"(sign), "r"(scale_neg_[0]), "r"(scale_neg_[1]), "r"(scale_pos_[0]), "r"(scale_pos_[1])); + } + } + + /// Utilities to dequantize A. 
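  // Scalar reference for the affine dequantization carried out per element by
  // the ConvertAndScale / ConvertAndScaleWithZero paths further below
  // (hypothetical helper, not called by the kernel): convert the quantized
  // value, multiply by the group scale, and optionally add a zero-point.
  // e.g. dequantize_ref<float>(int8_t(-3), 0.5f) == -1.5f
  template <class Out, class Q, class S>
  CUTLASS_HOST_DEVICE static Out dequantize_ref(Q q, S scale, S zero_point = S(0)) {
    return static_cast<Out>(static_cast<S>(q) * scale + zero_point);
  }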
+ template + CUTLASS_DEVICE static void static_check_scale(Layout const& tensor) { + static_assert( + shape<0>(Layout{}) >= 4 && stride<0>(Layout{}) == 0, + "At least 4 adjacent weights in a thread must share the same scale."); + } + template + CUTLASS_DEVICE static void static_check_scale(Tensor const& tensor) { + static_check_scale(flatten(Layout{})); + } + + template + CUTLASS_DEVICE static void dequantize_A_kblock( + Tensor const& tCrA_load, + Tensor& tCrA_mma, + cute::tuple& partitioned_extra_info, + int const k_block) { + static_assert(is_rmem::value, "Input tensor for A conversion must come from registers"); + static_assert(is_rmem::value, "Output tensor for A conversion must come from registers"); + static_assert(cosize_v == cosize_v); + static_assert(size_v == cosize_v); + static_assert(size_v == cosize_v); + using SrcType = typename EngineIn::value_type; + using DstType = typename EngineOut::value_type; + + Tensor src = tCrA_load(_, _, k_block); + Tensor dst = tCrA_mma(_, _, k_block); + + CUTE_STATIC_ASSERT_V( + size(src(_, 0)) == cosize(src(_, 0).layout()), "The first mode of tensor src must be contiguous in memory"); + // try to make the size of the first mode equal to 32bit + int constexpr NumValPerSrcReg = cute::min(decltype(size(src(_, 0)))::value, ceil_div(32, sizeof_bits_v)); + Tensor src_vm = cute::group_modes<1, -1>(cute::zipped_divide(src, Int{})); + Tensor dst_vm = cute::group_modes<1, -1>(cute::zipped_divide(dst, Int{})); + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<1>(dst_vm); ++i) { + LayoutAwareConvert(src_vm(_, i), dst_vm(_, i)); + } + } else if constexpr (UseScaleLookupTable) { + constexpr int num_elements = decltype(size(src))::value; + static_assert( + is_same_v, + "Lookup table only supports int4 being the quant type now."); + static_assert(sizeof_bits_v == 64, "Lookup table only supports 8 8bit scale values now."); + static_assert( + num_elements % 4 == 0 && num_elements >= 4, "Lookup table requires a vector size of 4x when converting."); + + Tensor tCrS_neg = cute::get<1>(partitioned_extra_info); + auto&& tCrS_pos = cute::get<2>(partitioned_extra_info); // modification to its value is needed + Tensor scales_neg = tCrS_neg(_, _, k_block); + Tensor scales_pos = tCrS_pos(_, _, k_block); + CUTE_STATIC_ASSERT_V(cute::size(src) == cute::size(scales_neg)); + + static_check_scale(scales_neg); + static_check_scale(scales_pos); + Tensor scales_neg_vm = cute::group_modes<1, -1>(cute::zipped_divide(scales_neg, Int{})); + Tensor scales_pos_vm = cute::group_modes<1, -1>(cute::zipped_divide(scales_pos, Int{})); + + if (k_block == 0) { + Tensor scales_neg_vm_ = filter(scales_neg_vm); + Tensor scales_pos_vm_ = filter(scales_pos_vm); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(scales_neg_vm_.layout()); ++i) { + auto&& scale_neg_ = reinterpret_cast const&>(scales_neg_vm_(i)); + auto&& scale_pos_ = reinterpret_cast&>(scales_pos_vm_(i)); + constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa; + asm volatile( + "{\n" + " lop3 .b32 %0, %2, %4, %5, %6;\n" + " xor .b32 %1, %3, %5; \n" + "}\n" + : "=r"(scale_pos_[0]), "=r"(scale_pos_[1]) + : "r"(scale_neg_[0]), "r"(scale_neg_[1]), "n"(0xFFFFFF00), "n"(0x80808080), "n"(immLut)); + } + } + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<1>(dst_vm); ++i) { + lookup_table_convert(src_vm(_, i), dst_vm(_, i), scales_neg_vm(_, i), scales_pos_vm(_, i)); + } + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + Tensor scales = 
cute::get<1>(partitioned_extra_info)(_, _, k_block); + CUTE_STATIC_ASSERT_V(size(src) == size(scales)); + Tensor scales_vm = cute::group_modes<1, -1>(cute::zipped_divide(scales, Int{})); + + if constexpr (is_same_v) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<1>(dst_vm); ++i) { + LayoutAwareConvert(src_vm(_, i), dst_vm(_, i)); + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size<0>(dst_vm); ++j) { + dst_vm(j, i) *= scales_vm(j, i); + } + } + } else { + auto stage = make_tensor_like(src_vm(_, 0)); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<1>(dst_vm); ++i) { + LayoutAwareConvert(src_vm(_, i), stage); + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size<0>(dst_vm); ++j) { + stage(j) *= scales_vm(j, i); + } + LayoutAwareConvert(stage, dst_vm(_, i)); + } + } + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + static_assert(is_same_v, "ElementScale and ElementZero must be the same."); + Tensor scales = cute::get<1>(partitioned_extra_info)(_, _, k_block); + Tensor zeros = cute::get<3>(partitioned_extra_info)(_, _, k_block); + CUTE_STATIC_ASSERT_V(size(src) == size(scales)); + CUTE_STATIC_ASSERT_V(size(src) == size(zeros)); + Tensor scales_vm = cute::group_modes<1, -1>(cute::zipped_divide(scales, Int{})); + Tensor zeros_vm = cute::group_modes<1, -1>(cute::zipped_divide(zeros, Int{})); + + if constexpr (is_same_v) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<1>(dst_vm); ++i) { + LayoutAwareConvert(src_vm(_, i), dst_vm(_, i)); + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size<0>(dst_vm); ++j) { + dst_vm(j, i) = dst_vm(j, i) * scales_vm(j, i) + zeros_vm(j, i); + } + } + } else { + auto stage = make_tensor_like(src_vm(_, 0)); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<1>(dst_vm); ++i) { + LayoutAwareConvert(src_vm(_, i), stage); + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size<0>(dst_vm); ++j) { + stage(j) = stage(j) * scales_vm(j, i) + zeros_vm(j, i); + } + LayoutAwareConvert(stage, dst_vm(_, i)); + } + } + } else { + static_assert(cutlass::detail::dependent_false, "No A data is loaded."); + } + } + + template + CUTLASS_DEVICE static void convert_A_kblock( + Tensor const& tCrA_load, Tensor& tCrA_mma, int const k_block) { + static_assert(is_rmem::value, "Input tensor for A conversion must come from registers"); + static_assert(is_rmem::value, "Output tensor for A conversion must come from registers"); + static_assert(cosize_v == cosize_v); + static_assert(size_v == cosize_v); + static_assert(size_v == cosize_v); + using SrcType = typename EngineIn::value_type; + + Tensor src = tCrA_load(_, _, k_block); + Tensor dst = tCrA_mma(_, _, k_block); + + CUTE_STATIC_ASSERT_V( + size(src(_, 0)) == cosize(src(_, 0).layout()), "The first mode of tensor src must be contiguous in memory"); + // try to make the size of the first mode equal to 32bit + int constexpr NumValPerSrcReg = cute::min(decltype(size(src(_, 0)))::value, ceil_div(32, sizeof_bits_v)); + Tensor src_vm = cute::group_modes<1, -1>(cute::zipped_divide(src, Int{})); + Tensor dst_vm = cute::group_modes<1, -1>(cute::zipped_divide(dst, Int{})); + + // KernelConversionMode == ConversionMode::DirectConvert + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<1>(dst_vm); ++i) { + LayoutAwareConvert(src_vm(_, i), dst_vm(_, i)); + } + } + + /// Utilities for any additional inputs inside of the TMA load + template + CUTLASS_DEVICE static auto partition_extra_tma_inputs( + Params const& mainloop_params, + cute::tuple const& load_inputs, + TensorStorage& shared_tensors, + uint2 const& 
cluster_local_block_id, + int const m_coord, + int const l_coord) { + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + return cute::make_tuple(); + } else if constexpr (ModeHasScales) { + Tensor sS = + make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{}); // (BLK_M,BLK_K,PIPE) + Tensor gS_mkl = get<2>(load_inputs); + auto block_tma_s = mainloop_params.tma_load_scale.get_slice(cluster_local_block_id.y); + Tensor gS = gS_mkl(_, _, m_coord, _, l_coord); // (BLK_M,BLK_K,k) + + Tensor tSgS = block_tma_s.partition_S(gS); // (TMA,TMA_M,TMA_K,k) + Tensor tSsS = block_tma_s.partition_D(sS); // (TMA,TMA_M,TMA_K,PIPE) + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return cute::make_tuple(tSgS, tSsS); + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + Tensor sZ = + make_tensor(make_smem_ptr(shared_tensors.smem_zero.begin()), SmemLayoutScale{}); // (BLK_M,BLK_K,PIPE) + Tensor gZ_mkl = get<3>(load_inputs); + auto block_tma_z = mainloop_params.tma_load_zero.get_slice(cluster_local_block_id.y); + Tensor gZ = gZ_mkl(_, _, m_coord, _, l_coord); // (BLK_M,BLK_K,k) + + Tensor tZgZ = block_tma_z.partition_S(gZ); // (TMA,TMA_M,TMA_K,k) + Tensor tZsZ = block_tma_z.partition_D(sZ); // (TMA,TMA_M,TMA_K,PIPE) + return cute::make_tuple(tSgS, tSsS, tZgZ, tZsZ); + } else { + static_assert( + cutlass::detail::dependent_false, "Conversion mode not handled for input partitioning."); + } + } else { + static_assert( + cutlass::detail::dependent_false, "Conversion mode not handled for input partitioning."); + } + } + + /// Utilities for partitioning extra inputs for loading from smem in the mainloop. + template + CUTLASS_DEVICE static auto + partition_extra_mma_info(ThreadMma const& mma_thread_slice, TensorStorage& shared_tensors) { + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + // nothing to do + return cute::make_tuple(); + } else if constexpr (UseScaleLookupTable) { + Tensor sS = + make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{}); // (BLK_M,BLK_SCALE_K,PIPE) + Tensor tCsS = mma_thread_slice.partition_A(sS); + Tensor tCrS = make_tensor(mma_thread_slice.partition_fragment_A(sS(_, _, Int<0>{})).layout()); + + return cute::make_tuple(tCsS, tCrS); + } else if constexpr (ModeHasScales) { + Tensor sS = + make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{}); // (BLK_M,BLK_SCALE_K,PIPE) + Tensor tCsS = mma_thread_slice.partition_A(sS); + Tensor tCrS = make_tensor(mma_thread_slice.partition_fragment_A(sS(_, _, Int<0>{})).layout()); + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return cute::make_tuple(tCsS, tCrS); + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + Tensor sZ = make_tensor( + make_smem_ptr(shared_tensors.smem_zero.begin()), SmemLayoutScale{}); // (BLK_M,BLK_SCALE_K,PIPE) + Tensor tCsZ = mma_thread_slice.partition_A(sZ); + Tensor tCrZ = make_tensor(mma_thread_slice.partition_fragment_A(sZ(_, _, Int<0>{})).layout()); + return cute::make_tuple(tCsS, tCrS, tCsZ, tCrZ); + } else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } + } else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } + } + + /// Returns the tiled copy and copy views for the extra inputs. 
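  // Sketch of the structural dispatch these helpers share: the conversion mode
  // is encoded in the arity of the returned tuple (no extra tensors for
  // DirectConvert, scale views for ConvertAndScale, scale plus zero views for
  // ConvertAndScaleWithZero), so the mainloop unpacks with cute::get<> under
  // if constexpr and never branches at runtime. Illustrative only, with ints
  // standing in for the partitioned tensors:
  CUTLASS_HOST_DEVICE static auto make_extra_inputs_sketch() {
    if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
      return cute::make_tuple();
    } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
      return cute::make_tuple(/*gS=*/0, /*sS=*/1);
    } else {
      return cute::make_tuple(/*gS=*/0, /*sS=*/1, /*gZ=*/2, /*sZ=*/3);
    }
  }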
+ template + CUTLASS_DEVICE static auto retile_extra_mma_info( + TiledMma const& tiled_mma, cute::tuple& partitioned_extra_info, int const warp_group_thread_idx) { + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + // nothing to do + return cute::make_tuple(); + } else if constexpr (ModeHasScales) { + auto smem_tiled_copy_S = make_tiled_copy_A(SmemCopyAtomScale{}, tiled_mma); + auto smem_thr_copy_S = smem_tiled_copy_S.get_thread_slice(warp_group_thread_idx); + Tensor tCrS_copy_view = smem_thr_copy_S.retile_D(cute::get<1>(partitioned_extra_info)); // (CPY,CPY_M,CPY_K) + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return cute::make_tuple(smem_tiled_copy_S, tCrS_copy_view); + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + Tensor tCrZ_copy_view = smem_thr_copy_S.retile_D(cute::get<3>(partitioned_extra_info)); // (CPY,CPY_M,CPY_K) + return cute::make_tuple(smem_tiled_copy_S, tCrS_copy_view, tCrZ_copy_view); + } else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } + } else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } + } +}; + +} // namespace cutlass::gemm::collective::detail diff --git a/sgl-kernel/csrc/cutlass_extensions/gemm/collective/builders/sm90_gmma_builder_mixed_input.inl b/sgl-kernel/csrc/cutlass_extensions/gemm/collective/builders/sm90_gmma_builder_mixed_input.inl new file mode 100644 index 000000000..db1fdf1e7 --- /dev/null +++ b/sgl-kernel/csrc/cutlass_extensions/gemm/collective/builders/sm90_gmma_builder_mixed_input.inl @@ -0,0 +1,278 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include "cute/arch/cluster_sm90.hpp" +#include "cute/tensor.hpp" +#include "cutlass/gemm/collective/builders/sm90_common.inl" +#include "cutlass/gemm/collective/collective_builder_decl.hpp" +#include "cutlass/gemm/collective/collective_mma_decl.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/pipeline/sm90_pipeline.hpp" + +// SM90 Collective Builders should be used only starting CUDA 12.0 +#if (__CUDACC_VER_MAJOR__ >= 12) +#define CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA_TMA_WS_RS +template < + class ElementA_, + class GmemLayoutATag_, + int AlignmentA, + class ElementB_, + class GmemLayoutBTag_, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + class StageCountType, + class KernelScheduleType> +struct CollectiveBuilderMixedInput< + arch::Sm90, + arch::OpClassTensorOp, + ElementA_, + GmemLayoutATag_, + AlignmentA, + ElementB_, + GmemLayoutBTag_, + AlignmentB, + ElementAccumulator, + TileShape_MNK, + ClusterShape_MNK, + StageCountType, + KernelScheduleType, + cute::enable_if_t< + (cute::is_same_v || + cute::is_same_v || + cute::is_same_v || + cute::is_same_v || + cute::is_same_v) && + (detail::is_use_rmem_A() || + // ConvertAndScale and ConvertAndScaleWithZero + cute::is_tuple::value || cute::is_tuple::value || + // DirectConvert + sizeof_bits::value != sizeof_bits::value)>> { + private: + using ScaleA = detail::deduce_mixed_width_dtype_t<1, ElementA_>; + using ScaleB = detail::deduce_mixed_width_dtype_t<1, ElementB_>; + using ZeroA = detail::deduce_mixed_width_dtype_t<2, ElementA_>; + using ZeroB = detail::deduce_mixed_width_dtype_t<2, ElementB_>; + static constexpr bool NeitherIsTuple = !cute::is_tuple::value && !cute::is_tuple::value; + // Determine if mixed input types. 
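  // How "mixed input" is determined: after unpacking the optional tuple, the
  // two operand widths are compared, so a W4A8 problem is mixed because the
  // weights are 4-bit and the activations 8-bit. A hypothetical instantiation
  // of this builder for that case would pass the quantized operand as a tuple
  // and the activation side as a plain type (the real choices live in the
  // w4a8 dispatch code, not in this hunk):
  //
  //   ElementB_ = cute::tuple<cutlass::int4b_t,                          // operand
  //                           cutlass::Array<cutlass::float_e4m3_t, 8>>; // scale LUT
  //   ElementA_ = cutlass::float_e4m3_t;                                 // activations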
+ static constexpr bool IsMixedInput = cute::sizeof_bits_v> != + cute::sizeof_bits_v>; + static constexpr bool IsArrayOfPointersGemm = cute::is_any_of_v< + KernelScheduleType, + KernelPtrArrayTmaWarpSpecializedCooperative, + KernelPtrArrayTmaWarpSpecializedPingpong>; + static_assert(IsMixedInput || !IsArrayOfPointersGemm, "Only mixed input grouped RS GEMM is supported."); + + public: + using ElementA = detail::deduce_mixed_width_dtype_t<0, ElementA_>; + using ElementB = detail::deduce_mixed_width_dtype_t<0, ElementB_>; + + static_assert( + !IsMixedInput || (cute::is_tuple::value ^ cute::is_tuple::value || + (NeitherIsTuple && (sizeof_bits::value != sizeof_bits::value))), + "Either A OR B must be a tuple or the widths of A and B must be different."); + + static constexpr bool IsANarrow = sizeof_bits::value < sizeof_bits::value; + + template + static auto get_stride(T const& t) { + if constexpr (not cute::is_layout>::value) { + return t; + } else { + if constexpr (cute::is_pointer_v) { + return &cute::stride(*t); + } else { + return cute::stride(t); + } + } + } + + using GmemLayoutATag = decltype(get_stride(GmemLayoutATag_{})); + using GmemLayoutBTag = decltype(get_stride(GmemLayoutBTag_{})); + + using ElementPairA = + cute::conditional_t, ElementA_>; + using ElementPairB = + cute::conditional_t, ElementB_>; + + static constexpr bool IsATransformed = cute::is_tuple::value; + using ElementScale = cute::conditional_t; + using ElementZero = cute::conditional_t; + + static_assert(is_static::value); + static_assert(is_static::value); + static_assert( + detail::is_aligned(), + "Should meet TMA alignment requirement\n"); +#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED + static_assert(cutlass::detail::dependent_false, "Unsupported Toolkit for SM90 Collective Builder\n"); +#endif + static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_rs_tag_to_major_A(); + static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_rs_tag_to_major_B(); + // If A is scaled, then we don't need to swap. Otherwise, we must ensure B goes to rmem and we must swap the + // operands. + static constexpr bool SwapAB = + IsMixedInput ? !IsATransformed : detail::is_swapAB(); + static constexpr bool IsWarpSpecializedTransposeB = + detail::is_warpspecialized_transpose_B(); + static_assert(!IsMixedInput || !IsWarpSpecializedTransposeB, "Mixed input GEMM does not support WS transpose B."); + + // When we relax the above assertion, we must handle setting the tile mma GmmaMajorB correctly. + static constexpr cute::GMMA::Major TiledMmaGmmaMajorB = SwapAB ? GmmaMajorA : GmmaMajorB; + + // For fp32 types, map to tf32 MMA value type. + using ElementAMma = cute::conditional_t, tfloat32_t, ElementA>; + using ElementBMma = cute::conditional_t, tfloat32_t, ElementB>; + + // Handle mixed dtypes and MMA. + using RealElementA = cute::conditional_t; + using RealElementB = cute::conditional_t; + using RealElementAMma = cute::conditional_t; + // Always the same for element B. 
+ using RealElementBMma = RealElementB; + + static_assert( + !IsMixedInput || TiledMmaGmmaMajorB == GMMA::Major::K || sizeof_bits::value == 16, + "Mixed input GEMM does not support MN major layout except for 16bit"); + + using AtomLayoutMNK = cute::conditional_t< + cute::is_any_of_v< + KernelScheduleType, + KernelTmaWarpSpecializedCooperative, + KernelPtrArrayTmaWarpSpecializedCooperative>, + Layout>, + Layout>>; + + using TiledMma = decltype(cute::make_tiled_mma( + cute::GMMA::rs_op_selector< + RealElementAMma, + RealElementBMma, + ElementAccumulator, + TileShape_MNK, + GMMA::Major::K, + GMMA::Major::K>(), + AtomLayoutMNK{})); + + using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{}))); + using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{}))); + + using SmemLayoutAtomA = decltype(detail::rs_smem_selector< + GmmaMajorA, + ElementAMma, + decltype(cute::get<0>(TileShape_MNK{})), + decltype(cute::get<2>(TileShape_MNK{})), + IsWarpSpecializedTransposeB>()); + using SmemLayoutAtomB = decltype(detail::rs_smem_selector< + GmmaMajorB, + ElementBMma, + decltype(cute::get<1>(TileShape_MNK{})), + decltype(cute::get<2>(TileShape_MNK{})), + IsWarpSpecializedTransposeB>()); + + static constexpr size_t SmemAlignmentA = cutlass::detail::alignment_for_swizzle(SmemLayoutAtomA{}); + static constexpr size_t SmemAlignmentB = cutlass::detail::alignment_for_swizzle(SmemLayoutAtomB{}); + static constexpr int SmemAlignment = static_cast(cute::max(SmemAlignmentA, SmemAlignmentB)); + + // Handle mixed dtype array GEMM's size of tensor map storage. + static constexpr size_t TensorMapStorage = sizeof(cute::TmaDescriptor) * size_t(IsMixedInput) * 4; + static constexpr int KernelSmemCarveout = static_cast(TensorMapStorage); + static constexpr int Sm90ReducedSmemCapacityBytes = detail::sm90_smem_capacity_bytes - KernelSmemCarveout; + + static constexpr int PipelineStages = + IsMixedInput ? (IsArrayOfPointersGemm ? detail::compute_stage_count_or_override_single_affine_transformed_input< + Sm90ReducedSmemCapacityBytes, + RealElementA, + RealElementB, + ElementScale, + ElementZero, + TileShape_MNK, + StageCountType::bytes, + SmemAlignment>(StageCountType{}) + : detail::compute_stage_count_or_override_single_affine_transformed_input< + detail::sm90_smem_capacity_bytes, + RealElementA, + RealElementB, + ElementScale, + ElementZero, + TileShape_MNK, + StageCountType::bytes, + SmemAlignment>(StageCountType{})) + : detail::compute_stage_count_or_override< + detail::sm90_smem_capacity_bytes, + ElementAMma, + ElementBMma, + TileShape_MNK, + StageCountType::bytes, + SmemAlignment>(StageCountType{}); + + using DispatchPolicy = cute::conditional_t< + IsMixedInput, + cute::conditional_t< + IsArrayOfPointersGemm, + MainloopSm90ArrayTmaGmmaWarpSpecializedMixedInput, + MainloopSm90TmaGmmaRmemAWarpSpecializedMixedInput>, + MainloopSm90TmaGmmaRmemAWarpSpecialized>; + + using SmemCopyAtomA = cute::conditional_t>; + using SmemCopyAtomB = cute::conditional_t, void>; + + // We pack the scale data with the operand that will be optionally scaled and converted before MMA. 
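  // Worked numbers for the tensormap carveout computed above (assuming the
  // standard 128-byte CUtensorMap descriptor): the grouped mixed-input kernel
  // stages four descriptors -- A, B, scale and zero -- in shared memory per SM,
  // so
  //   KernelSmemCarveout = 4 * sizeof(cute::TmaDescriptor) = 4 * 128 = 512 bytes
  // (and zero bytes when the inputs are not mixed) is subtracted from the SM90
  // smem budget before PipelineStages is computed.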
+ using StrideA = cute::conditional_t< + cute::is_layout>::value, + GmemLayoutATag_, + TagToStrideA_t>; + using StrideB = cute::conditional_t< + cute::is_layout>::value, + GmemLayoutBTag_, + TagToStrideB_t>; + + using CollectiveOp = CollectiveMmaArrayMixedInput< + DispatchPolicy, + TileShape_MNK, + ElementPairA, + StrideA, + ElementPairB, + StrideB, + TiledMma, + GmemTiledCopyA, + SmemLayoutAtomA, + SmemCopyAtomA, + cute::identity, + GmemTiledCopyB, + SmemLayoutAtomB, + SmemCopyAtomB, + cute::identity>; + + static_assert( + SmemAlignment == static_cast(cute::max(CollectiveOp::SmemAlignmentA, CollectiveOp::SmemAlignmentB))); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/csrc/cutlass_extensions/gemm/collective/collective_builder_mixed_input.hpp b/sgl-kernel/csrc/cutlass_extensions/gemm/collective/collective_builder_mixed_input.hpp new file mode 100644 index 000000000..8288ace95 --- /dev/null +++ b/sgl-kernel/csrc/cutlass_extensions/gemm/collective/collective_builder_mixed_input.hpp @@ -0,0 +1,52 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +///////////////////////////////////////////////////////////////////////////////////////////////// +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass_extensions/gemm/collective/collective_mma_array_mixed_input.hpp" + +namespace cutlass::gemm::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class ArchTag, + class OpClass, + class ElementA, + class GmemLayoutA, + int AlignmentA, + class ElementB, + class GmemLayoutB, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + class StageCountType, + class KernelScheduleType, + class Enable = void> +struct CollectiveBuilderMixedInput { + static_assert(sizeof(ElementA) == 0, "Could not build a collective for given parameters."); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include "cutlass_extensions/gemm/collective/builders/sm90_gmma_builder_mixed_input.inl" +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/csrc/cutlass_extensions/gemm/collective/collective_mma_array_mixed_input.hpp b/sgl-kernel/csrc/cutlass_extensions/gemm/collective/collective_mma_array_mixed_input.hpp new file mode 100644 index 000000000..08afdffd4 --- /dev/null +++ b/sgl-kernel/csrc/cutlass_extensions/gemm/collective/collective_mma_array_mixed_input.hpp @@ -0,0 +1,53 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include "cutlass/detail/dependent_false.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class DispatchPolicy, + class TileShape, + class ElementA, + class StrideA, + class ElementB, + class StrideB, + class TiledMma, + class GmemTiledCopyA, + class SmemLayoutAtomA, + class SmemCopyAtomA, + class TransformA, + class GmemTiledCopyB, + class SmemLayoutAtomB, + class SmemCopyAtomB, + class TransformB> +struct CollectiveMmaArrayMixedInput { + static_assert(cutlass::detail::dependent_false, "Could not find a mainloop specialization."); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include "cutlass_extensions/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input_.hpp" +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/csrc/cutlass_extensions/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input_.hpp b/sgl-kernel/csrc/cutlass_extensions/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input_.hpp new file mode 100644 index 000000000..6e1a01e22 --- /dev/null +++ b/sgl-kernel/csrc/cutlass_extensions/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input_.hpp @@ -0,0 +1,1535 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cute/algorithm/functional.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cute/arch/copy_sm90.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" +#include "cute/tensor_predicate.hpp" +#include "cutlass/cuda_host_adapter.hpp" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/numeric_types.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/trace.h" +#include "cutlass_extensions/detail/collective/mixed_input_utils.hpp" + +#define GROUP_SIZE 128 + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +template < + int Stages, + class ClusterShape, + class KernelSchedule_, + class TileShape_, + class ElementAOptionalTuple, + class StrideA_, + class ElementBOptionalTuple, + class StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMmaArrayMixedInput< + MainloopSm90ArrayTmaGmmaWarpSpecializedMixedInput, + TileShape_, + ElementAOptionalTuple, + StrideA_, + ElementBOptionalTuple, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> { + public: + enum class ConversionMode { DirectConvert, ConvertAndScale, ConvertAndScaleWithZero }; + + // + // Type Aliases + // + using DispatchPolicy = MainloopSm90ArrayTmaGmmaWarpSpecializedMixedInput; + using TileShape = TileShape_; + using KernelSchedule = KernelSchedule_; + + private: + template + friend struct detail::MixedGroupedGemmInputUtils; + using CollectiveType = CollectiveMma< + DispatchPolicy, + TileShape_, + ElementAOptionalTuple, + StrideA_, + ElementBOptionalTuple, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_>; + using Utils = detail::MixedGroupedGemmInputUtils; + + // + // Type Aliases + // + using ScaleA = detail::deduce_mixed_width_dtype_t<1, ElementAOptionalTuple>; + using ScaleB = detail::deduce_mixed_width_dtype_t<1, ElementBOptionalTuple>; + using ZeroA = detail::deduce_mixed_width_dtype_t<2, ElementAOptionalTuple>; + using ZeroB = detail::deduce_mixed_width_dtype_t<2, ElementBOptionalTuple>; + + public: + static_assert( + cute::is_tuple::value ^ cute::is_tuple::value, + "Either A OR B must be a tuple. It must take the from {ElementOperand, [ElementScale], [ElementZero]}. Inputs in " + "[] are optional."); + + using ElementA = detail::deduce_mixed_width_dtype_t<0, ElementAOptionalTuple>; + using ElementB = detail::deduce_mixed_width_dtype_t<0, ElementBOptionalTuple>; + static constexpr bool IsATransformed = cute::is_tuple::value; + using ElementScale = cute::conditional_t; + using ElementZero = cute::conditional_t; + // For cases where we can't have a void type, we can use this to allow the code to compile when the scale / zero is + // void. 
+ using NonVoidElementScale = cute::conditional_t, float, ElementScale>; + using NonVoidElementZero = cute::conditional_t, float, ElementZero>; + + using StrideA = StrideA_; + using InternalStrideA = cute::remove_pointer_t; + using StrideB = StrideB_; + using InternalStrideB = cute::remove_pointer_t; + + using StrideScale = cute::Stride, int64_t, int64_t>; + using NonVoidStrideScale = + cute::conditional_t, cute::Stride<_1, int64_t, int64_t>, StrideScale>; + + static_assert( + (IsATransformed && (cutlass::gemm::detail::is_k_major() || is_layout::value || + is_layout::value)) || + (!IsATransformed && (cutlass::gemm::detail::is_k_major() || is_layout::value || + is_layout::value)), + "The transformed type must be K-major."); + + static_assert( + (IsATransformed && (sizeof(ElementB) == 2)) || (!IsATransformed && (sizeof(ElementA) == 2)) || + ((cutlass::gemm::detail::is_k_major() || is_layout::value || + is_layout::value) && + (cutlass::gemm::detail::is_k_major() || is_layout::value || + is_layout::value)), + "The unscaled element must be 2 bytes OR both inputs must be K-major"); + + static_assert( + cutlass::gemm::detail::is_mn_major(), + "Scale must be MN major [Col Major if A is scaled, Row Major if B is scaled]."); + + using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{})); + using TiledMma = TiledMma_; + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using GmemTiledCopyScale = cute::SM90_TMA_LOAD; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using SmemCopyAtomScale = Copy_Atom; + + // We must ensure the type to be scaled goes to RF + static constexpr bool SwapAB = !IsATransformed; + using SwappedStrideA = cute::conditional_t; + using SwappedStrideB = cute::conditional_t; + using InternalSwappedStrideA = cute::conditional_t; + using InternalSwappedStrideB = cute::conditional_t; + using SwappedSmemLayoutAtomA = cute::conditional_t; + using SwappedSmemLayoutAtomB = cute::conditional_t; + using SwappedSmemCopyAtomA = cute::conditional_t; + using SwappedSmemCopyAtomB = cute::conditional_t; + // TMA converts f32 input to tf32 when copying from GMEM to SMEM + // For all other types, cast to size equivalent uint type to avoid any rounding by TMA. + static constexpr bool ConvertF32toTF32A = cute::is_same_v; + static constexpr bool ConvertF32toTF32B = cute::is_same_v; + using ConvertedElementA = cute::conditional_t>>; + using ConvertedElementB = cute::conditional_t>>; + using RealSwappedElementA = cute::conditional_t; + using RealSwappedElementB = cute::conditional_t; + using SwappedElementA = cute::conditional_t; + using SwappedElementB = cute::conditional_t; + + using TransformA = TransformA_; + using TransformB = TransformB_; + using SwappedTransformA = cute::conditional_t; + using SwappedTransformB = cute::conditional_t; + using ArchTag = typename DispatchPolicy::ArchTag; + + static constexpr int IsSubbyteA = cute::sizeof_bits_v < 8; + using TmaElementA = cute::conditional_t; + using TmaElementScale = uint_bit_t>; // in case we have array. 
translating to uint + // to satisfy tma descriptor's specialization + + using MainloopPipeline = cutlass::PipelineTmaAsync; + using PipelineState = cutlass::PipelineState; + using PipelineParams = typename MainloopPipeline::Params; + + static constexpr int NumProducerThreadEvents = 1; + + using SmemLayoutAtomScale = Layout(SwappedSmemLayoutAtomA{})), cute::Int<1>>>; + using ScaleTileShape = decltype(make_shape(shape<0>(TileShape{}), shape<1>(SmemLayoutAtomScale{}))); + + static_assert(cute::rank(SwappedSmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert( + (size<0>(TileShape{}) % size<0>(SwappedSmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert( + (size<2>(TileShape{}) % size<1>(SwappedSmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(cute::rank(SwappedSmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert( + (size<1>(TileShape{}) % size<0>(SwappedSmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert( + (size<2>(TileShape{}) % size<1>(SwappedSmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(rank(SmemLayoutAtomScale{}) == 2, "SmemLayoutAtomScale must be rank 2"); + static_assert( + (size<0>(TileShape{}) % size<0>(SmemLayoutAtomScale{})) == 0, "SmemLayoutAtomScale must equal the tile shape."); + static_assert( + (size<2>(TileShape{}) % size<1>(SmemLayoutAtomScale{})) == 0, + "SmemLayoutAtomScale must evenly divide tile k shape."); + + /// Tile along modes in a way that maximizes the TMA box size. + using SmemLayoutA = decltype(detail::get_smem_layout( + SwappedSmemLayoutAtomA{}, select<0, 2>(TileShape{}), InternalSwappedStrideA{})); + using SmemLayoutB = decltype(detail::get_smem_layout( + SwappedSmemLayoutAtomB{}, select<1, 2>(TileShape{}), InternalSwappedStrideB{})); + + // It is assumed that the scales and zero-points share the same smem layout + using SmemLayoutScale = decltype(tile_to_shape( + SmemLayoutAtomScale{}, + make_shape(shape<0>(ScaleTileShape{}), shape<1>(ScaleTileShape{}), Int{}), + cute::conditional_t< + ::cutlass::gemm::detail::is_major<0, NonVoidStrideScale>(), + Step<_2, _1, _3>, + Step<_1, _2, _3>>{})); + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 2 or more."); + static_assert( + not cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source A from rmem and B operand from smem_desc for this mainloop."); + static_assert( + cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + static_assert( + cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + + // To relax them, we need to handle loading more than 1 row of scales for every main loop iteration. + // We must also handle updating the pipeline transaction bytes on the fly. 
+ static_assert(size<1>(SmemLayoutAtomScale{}) == 1, "size<1>(SmemLayoutAtomScale) must be 1."); + + private: + static constexpr ConversionMode get_conversion_mode() { + if constexpr (cute::is_void_v) { + return ConversionMode::DirectConvert; + } else if constexpr (cute::is_void_v) { + return ConversionMode::ConvertAndScale; + } else { + return ConversionMode::ConvertAndScaleWithZero; + } + } + + public: + static constexpr ConversionMode KernelConversionMode = get_conversion_mode(); + static constexpr bool ModeHasScales = KernelConversionMode == ConversionMode::ConvertAndScale || + KernelConversionMode == ConversionMode::ConvertAndScaleWithZero; + static constexpr bool UseScaleLookupTable = + KernelConversionMode == ConversionMode::ConvertAndScale && cutlass::detail::is_Array_v; + static constexpr size_t SmemAlignmentA = cutlass::detail::alignment_for_swizzle(SmemLayoutA{}); + static constexpr size_t SmemAlignmentB = cutlass::detail::alignment_for_swizzle(SmemLayoutB{}); + static constexpr size_t SmemAlignmentScale = cute::max(SmemAlignmentA, SmemAlignmentB); + + static_assert(SmemAlignmentA >= 128 and SmemAlignmentB >= 128, "Require at least 128B alignment"); + + struct SharedStorage { + static constexpr int scale_elements = Utils::elements_per_smem_scale(); + static constexpr int zero_elements = Utils::elements_per_smem_zero(); + struct TensorStorage { + CUTE_ALIGNAS(SmemAlignmentA) cute::ArrayEngine> smem_A; + CUTE_ALIGNAS(SmemAlignmentB) cute::ArrayEngine> smem_B; + cute::ArrayEngine smem_scale; + cute::ArrayEngine smem_zero; + } tensors; + + struct TensorMapStorage { + cute::TmaDescriptor smem_tensormap_A; + cute::TmaDescriptor smem_tensormap_B; + cute::TmaDescriptor smem_tensormap_scale; + cute::TmaDescriptor smem_tensormap_zero; + }; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + using TensorMapStorage = typename SharedStorage::TensorMapStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + static constexpr bool IsGroupedGemmKernel = !cute::is_same_v; + + // kernel Arguments + // Host side kernel arguments + struct Arguments { + ElementA const** ptr_A; + StrideA dA; + ElementB const** ptr_B; + StrideB dB; + ElementScale const** ptr_S = nullptr; + NonVoidStrideScale const* dS{}; + int chunk_size = 0; + ElementZero const** ptr_Z = nullptr; + }; + + // Device side kernel params + struct Params { + // Assumption: StrideA is congruent with Problem_MK + using LayoutA = + decltype(detail::get_gmem_layout(repeat_like(InternalSwappedStrideA{}, int32_t(0)), InternalSwappedStrideA{})); + using LayoutB = + decltype(detail::get_gmem_layout(repeat_like(InternalSwappedStrideB{}, int32_t(0)), InternalSwappedStrideB{})); + + using TMA_A = decltype(make_tma_copy( + GmemTiledCopyA{}, + make_tensor(detail::get_logical_ptr(static_cast(nullptr)), LayoutA{}), + SmemLayoutA{}(_, _, cute::Int<0>{}), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any + // Assumption: StrideB is congruent with Problem_NK + using TMA_B = decltype(make_tma_copy( + GmemTiledCopyB{}, + make_tensor(detail::get_logical_ptr(static_cast(nullptr)), LayoutB{}), + SmemLayoutB{}(_, _, cute::Int<0>{}), + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any + + using TMA_Scale = decltype(make_tma_copy( + GmemTiledCopyScale{}, + 
make_tensor( + detail::get_logical_ptr(static_cast(nullptr)), + repeat_like(NonVoidStrideScale{}, int32_t(0)), + NonVoidStrideScale{}), + SmemLayoutScale{}(_, _, cute::Int<0>{}), + ScaleTileShape{}, + _1{})); // mcast along N mode for this M load, if any. Scale is ALWAYS loaded with A for RF kernel + + using TMA_Zero = decltype(make_tma_copy( + GmemTiledCopyScale{}, + make_tensor( + detail::get_logical_ptr(static_cast(nullptr)), + repeat_like(NonVoidStrideScale{}, int32_t(0)), + NonVoidStrideScale{}), + SmemLayoutScale{}(_, _, cute::Int<0>{}), + ScaleTileShape{}, + _1{})); // mcast along N mode for this M load, if any. Scale is ALWAYS loaded with A for RF kernel + + TMA_A tma_load_a; + TMA_B tma_load_b; + uint32_t tma_transaction_bytes = TmaTransactionBytes; + TMA_Scale tma_load_scale; + TMA_Zero tma_load_zero; + void* tensormaps; + SwappedElementA const** ptr_A; + SwappedStrideA ptr_dA; + SwappedElementB const** ptr_B; + SwappedStrideB ptr_dB; + NonVoidElementScale const** ptr_S; + NonVoidStrideScale const* dS; + NonVoidElementZero const** ptr_Z; + int64_t scale_k; + int chunk_size; + int reload_factor = (chunk_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{}); + InternalSwappedStrideA dA; + InternalSwappedStrideB dB; + }; + + // + // Methods + // + + template + static constexpr Params to_underlying_arguments(ProblemShape problem_shapes, Arguments const& args, void* workspace) { + // These tensor shapes (only applicable for grouped gemm) and pointers are only used to create tensormap/tma desc. + // These will be replaced with correct values before the initial tma load. + auto init_shape = repeat_like(typename ProblemShape::UnderlyingProblemShape{}, int32_t(1)); + auto init_M = get<0>(init_shape); + auto init_N = get<1>(init_shape); + auto init_K = get<2>(init_shape); + + if constexpr (SwapAB) { + init_M = get<1>(init_shape); + init_N = get<0>(init_shape); + } + // Batches/Groups are managed by using appropriate pointers to input matrices + const uint32_t mock_L = 1; + SwappedElementA const* ptr_A_first_batch; + SwappedElementB const* ptr_B_first_batch; + SwappedStrideA ptr_dA; + SwappedStrideB ptr_dB; + InternalSwappedStrideA dA; + InternalSwappedStrideB dB; + + if constexpr (not SwapAB) { + ptr_A_first_batch = reinterpret_cast(args.ptr_A); + ptr_B_first_batch = reinterpret_cast(args.ptr_B); + } else { + ptr_A_first_batch = reinterpret_cast(args.ptr_B); + ptr_B_first_batch = reinterpret_cast(args.ptr_A); + } + + if constexpr (IsGroupedGemmKernel) { + // Strides for Grouped Gemm will be replaced prior to the first access regardless. + if constexpr (not SwapAB) { + ptr_dA = args.dA; + ptr_dB = args.dB; + } else { + ptr_dA = args.dB; + ptr_dB = args.dA; + } + dA = InternalSwappedStrideA{}; + if constexpr (is_layout::value) { + dA = make_layout( + transform_leaf( + dA.shape(), + [](auto x) { + if constexpr (not is_static_v) { + return static_cast(1); + } else { + return x; + } + }), + dA.stride()); + } + dB = InternalSwappedStrideB{}; + } else { + // Tensor shapes for Ptr-Array are initialized correctly only here. 
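        // Aside on the reload_factor values computed in the branches below: the
        // formula is a plain ceiling division of the quantization group size by
        // the K extent of the tile, i.e. how many k iterations reuse one group
        // of scales before the next one must be fetched. For example, with
        // chunk_size = GROUP_SIZE = 128 and a hypothetical 64-deep K-tile:
        //   reload_factor = (128 + 64 - 1) / 64 = 2
        // so the same scales are reused for two consecutive k iterations.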
+ auto problem_shape_MNK = problem_shapes.get_host_problem_shape(0); + init_M = get<0>(problem_shape_MNK); + init_N = get<1>(problem_shape_MNK); + init_K = get<2>(problem_shape_MNK); + + if constexpr (not SwapAB) { + dA = args.dA; + dB = args.dB; + } else { + dA = args.dB; + dB = args.dA; + } + ptr_dA = SwappedStrideA{}; + ptr_dB = SwappedStrideB{}; + } + Tensor tensor_a = make_tensor(ptr_A_first_batch, detail::get_gmem_layout(make_shape(init_M, init_K, mock_L), dA)); + Tensor tensor_b = make_tensor(ptr_B_first_batch, detail::get_gmem_layout(make_shape(init_N, init_K, mock_L), dB)); + + typename Params::TMA_A tma_load_a = make_tma_copy( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_, _, cute::Int<0>{}), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{})); // mcast along N mode for this M load, if any + typename Params::TMA_B tma_load_b = make_tma_copy( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_, _, cute::Int<0>{}), + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + typename Params::TMA_Scale tma_load_scale{}; + typename Params::TMA_Zero tma_load_zero{}; + + void* tensormaps = workspace; + auto args_setup = + [&](auto ptr_A, auto ptr_B, int64_t scale_k = 0, int chunk_size = 0, int reload_factor = 1) -> Params { + return { + tma_load_a, + tma_load_b, + TmaTransactionBytes, + tma_load_scale, + tma_load_zero, + tensormaps, + reinterpret_cast(ptr_A), + ptr_dA, + reinterpret_cast(ptr_B), + ptr_dB, + reinterpret_cast(args.ptr_S), + args.dS, + reinterpret_cast(args.ptr_Z), + scale_k, + chunk_size, + reload_factor, + dA, + dB}; + }; + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + return SwapAB ? args_setup(args.ptr_B, args.ptr_A) : args_setup(args.ptr_A, args.ptr_B); + } else if constexpr (ModeHasScales) { + auto fake_scale_k = 1; + ElementScale const* ptr_S = reinterpret_cast(args.ptr_S); + StrideScale dS{}; + Tensor tensor_scale = + make_tensor(detail::get_logical_ptr(ptr_S), make_layout(make_shape(init_M, fake_scale_k, mock_L), dS)); + tma_load_scale = make_tma_copy( + GmemTiledCopyScale{}, + tensor_scale, + SmemLayoutScale{}(_, _, cute::Int<0>{}), + ScaleTileShape{}, + _1{}); // mcast along N mode for this M load, if any + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return SwapAB ? args_setup( + args.ptr_B, + args.ptr_A, + fake_scale_k, + args.chunk_size, + (args.chunk_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{})) + : args_setup( + args.ptr_A, + args.ptr_B, + fake_scale_k, + args.chunk_size, + (args.chunk_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{})); + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + ElementZero const* ptr_Z = reinterpret_cast(args.ptr_Z); + Tensor tensor_zero = + make_tensor(detail::get_logical_ptr(ptr_Z), make_layout(make_shape(init_M, fake_scale_k, mock_L), dS)); + tma_load_zero = make_tma_copy( + GmemTiledCopyScale{}, + tensor_zero, + SmemLayoutScale{}(_, _, cute::Int<0>{}), + ScaleTileShape{}, + _1{}); // mcast along N mode for this M load, if any + return SwapAB ? 
args_setup( + args.ptr_B, + args.ptr_A, + fake_scale_k, + args.chunk_size, + (args.chunk_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{})) + : args_setup( + args.ptr_A, + args.ptr_B, + fake_scale_k, + args.chunk_size, + (args.chunk_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{})); + + } else { + static_assert( + cutlass::detail::dependent_false, + "Conversion mode not handled in to_underlying_arguments."); + } + } else { + static_assert( + cutlass::detail::dependent_false, "Conversion mode not handled in to_underlying_arguments."); + } + } + + template + static size_t get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count) { + constexpr size_t SizeOfCuTensorMap = sizeof(cute::TmaDescriptor); + + // Calculating workspace size + auto calculate_workspace_size = [SizeOfCuTensorMap, sm_count](uint32_t num_input_tensors) { + return num_input_tensors * SizeOfCuTensorMap * sm_count; + }; + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + // Allocate gmem space for input tensormaps per each SM, A tensormap copies followed by B tensormap copies + return calculate_workspace_size(2); + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + // Allocate gmem space for input tensormaps per each SM, A tensormap copies followed by B tensormap copies, + // followed by scale tensormap copies + return calculate_workspace_size(3); + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + // Allocate gmem space for input tensormaps per each SM, A tensormap copies followed by B tensormap copies, + // followed by scale and zeros tensormap copies + return calculate_workspace_size(4); + } else { + static_assert( + cutlass::detail::dependent_false, "Conversion mode not handled in get_workspace_size."); + } + } + + template + static cutlass::Status initialize_workspace( + ProblemShape const& problem_shape, + Arguments const& args, + void* workspace, + cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + template + CUTLASS_HOST_DEVICE static bool can_implement(ProblemShape problem_shapes, Arguments const& args) { + constexpr int tma_alignment_bits = 128; + constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; + constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; + + bool implementable = true; + if (problem_shapes.is_host_problem_shape_available()) { + // Check alignment for all problem sizes + for (int i = 0; i < problem_shapes.groups(); i++) { + auto problem_shape_MNKL = append<4>(problem_shapes.get_host_problem_shape(i), 1); + auto [M, N, K, L] = problem_shape_MNKL; + auto get_stride = [](auto stride) { + if constexpr (cute::is_pointer_v>) { + return *stride; + } else { + return stride; + } + }; + auto dA = get_stride(args.dA); + auto dB = get_stride(args.dB); + implementable = implementable && cutlass::detail::check_alignment( + detail::get_gmem_layout(cute::make_shape(M, K, L), dA)); + implementable = implementable && cutlass::detail::check_alignment( + detail::get_gmem_layout(cute::make_shape(N, K, L), dB)); + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + implementable = implementable && (args.ptr_S == nullptr); + implementable = implementable && (args.ptr_Z == nullptr); + } else if constexpr (ModeHasScales) { + const int scale_mn = SwapAB ? 
N : M; + const int scale_k = (K + args.chunk_size - 1) / args.chunk_size; + constexpr int min_tma_aligned_elements_scale = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment( + cute::make_shape(scale_mn, scale_k, L), StrideScale{}); + implementable = implementable && (args.chunk_size == K || ((args.chunk_size % size<2>(TileShape{})) == 0)); + implementable = implementable && args.chunk_size != 0; + implementable = implementable && (args.ptr_S != nullptr); + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + implementable = implementable && (args.ptr_Z == nullptr); + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + constexpr int min_tma_aligned_elements_zero = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment( + cute::make_shape(scale_mn, scale_k, L), StrideScale{}); + implementable = implementable && (args.ptr_Z != nullptr); + } else { + static_assert( + cutlass::detail::dependent_false, "Conversion mode not handled in can_implement."); + } + } else { + static_assert( + cutlass::detail::dependent_false, "Conversion mode not handled in can_implement."); + } + } + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; + static constexpr int K_PIPE_MMAS = 1; + static constexpr uint32_t TmaTransactionBytesMK = Utils::compute_tma_transaction_bytes_mk(); + static constexpr uint32_t TmaTransactionBytesNK = Utils::compute_tma_transaction_bytes_nk(); + static constexpr uint32_t TmaTransactionBytesExtra = Utils::compute_tma_transaction_bytes_extra(); + static constexpr uint32_t TmaTransactionBytes = + TmaTransactionBytesMK + TmaTransactionBytesNK + TmaTransactionBytesExtra; + + // Set up the data needed by this collective for load and mma. + // Returns a tuple of tensors. The collective and the kernel layer have the contract that the + // returned tuple must contain at least two elements, with the first two elements being: + // gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l) + // gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l) + // The rest of the tensors can be specified as needed by this collective. 
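+  // In the scaled conversion modes the returned tuple is extended accordingly:
+  //   ConvertAndScale         -> (gA_mkl, gB_nkl, gS_mkl)
+  //   ConvertAndScaleWithZero -> (gA_mkl, gB_nkl, gS_mkl, gZ_mkl)
+  // where gS_mkl / gZ_mkl are the scale / zero tensors tiled to (BLK_M,BLK_Scale_K,m,scale_k,l).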
+ template + CUTLASS_DEVICE auto load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const { + using X = Underscore; + // Separate out problem shape for convenience + auto [M, N, K, L] = problem_shape_MNKL; + const int32_t mock_L = 1; + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor( + shape(detail::get_gmem_layout(make_shape(M, K, mock_L), mainloop_params.dA))); // (m,k,l) + Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor( + shape(detail::get_gmem_layout(make_shape(N, K, mock_L), mainloop_params.dB))); // (n,k,l) + + // Make tiled views, defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_, _, _), Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_, _, _), Step{}); // (BLK_N,BLK_K,n,k,l) + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + return cute::make_tuple(gA_mkl, gB_nkl); + } else if constexpr (ModeHasScales) { + // The real scale_k that actually works + // auto scale_k = K / mainloop_params.chunk_size; + auto scale_k = K / GROUP_SIZE; + + Tensor mS_mkl = mainloop_params.tma_load_scale.get_tma_tensor(make_shape(M, scale_k, L)); // (m,scale_k,l) + Tensor gS_mkl = local_tile(mS_mkl, ScaleTileShape{}, make_coord(_, _)); // (BLK_M,BLK_Scale_K,m,scale_k,l) + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return cute::make_tuple(gA_mkl, gB_nkl, gS_mkl); + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + Tensor mZ_mkl = mainloop_params.tma_load_zero.get_tma_tensor(make_shape(M, scale_k, L)); // (m,scale_k,l) + Tensor gZ_mkl = local_tile(mZ_mkl, ScaleTileShape{}, make_coord(_, _)); // (BLK_M,BLK_Scale_K,m,scale_k,l) + return cute::make_tuple(gA_mkl, gB_nkl, gS_mkl, gZ_mkl); + } else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in load_init."); + } + } else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in load_init."); + } + } + + ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + // Perform a collective-scoped matrix multiply-accumulate + // Producer Perspective + template + CUTLASS_DEVICE void load( + Params const& mainloop_params, + MainloopPipeline pipeline, + PipelineState smem_pipe_write, + cute::tuple const& load_inputs, + cute::tuple const& input_tensormaps, + BlockCoord const& blk_coord, + KTileIterator k_tile_iter, + int k_tile_count, + int thread_idx, + uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) { + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + static_assert(sizeof...(Ts) == 2, "Direct convert needs two inputs"); + static_assert(sizeof...(TMs) == 2, "Direct convert needs two tensormaps"); + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + static_assert(sizeof...(Ts) == 3, "Scaled convert needs three inputs"); + static_assert(sizeof...(TMs) == 3, "Scaled convert needs three tensormaps"); + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + static_assert(sizeof...(Ts) == 4, "Scaled and zero convert needs four inputs"); + static_assert(sizeof...(TMs) == 4, "Scaled and zero convert needs four tensormaps"); + } else { + 
static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in TMA load."); + } + + Tensor sA_ = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB_ = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sA = as_position_independent_swizzle_tensor(sA_); // (BLK_M,BLK_K,PIPE) + Tensor sB = as_position_independent_swizzle_tensor(sB_); // (BLK_N,BLK_K,PIPE) + + // + // Prepare the TMA loads for A and B + // + + constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + + Tensor gA_mkl = get<0>(load_inputs); + Tensor gB_nkl = get<1>(load_inputs); + + auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); + auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); + + // Partition the inputs based on the current block coordinates. + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gA = gA_mkl(_, _, m_coord, _, l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_, _, n_coord, _, l_coord); // (BLK_N,BLK_K,k) + + // Applies the mapping from block_tma_a + Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + + uint16_t mcast_mask_a = 0; + uint16_t mcast_mask_b = 0; + uint16_t mcast_mask_s = 0; + + // Issue TmaLoads + // Maps the tile -> block, value + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int n = 0; n < size<1>(block_layout); ++n) { + mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x, n, Int<0>{})); + } + } + + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < size<0>(block_layout); ++m) { + mcast_mask_b |= (uint16_t(1) << block_layout(m, cluster_local_block_id.y, Int<0>{})); + } + } + + auto extra_input_partitions = Utils::partition_extra_tma_inputs( + mainloop_params, load_inputs, shared_tensors, cluster_local_block_id, m_coord, l_coord); + + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for (; k_tile_count > 0; --k_tile_count) { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + + // + // Copy gmem to smem for *k_tile_iter + // + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + + int write_stage = smem_pipe_write.index(); + if (cute::elect_one_sync()) { + copy( + mainloop_params.tma_load_a.with(get<0>(input_tensormaps), *tma_barrier, mcast_mask_a), + tAgA(_, _, _, *k_tile_iter), + tAsA(_, _, _, write_stage)); + copy( + mainloop_params.tma_load_b.with(get<1>(input_tensormaps), *tma_barrier, mcast_mask_b), + tBgB(_, _, _, *k_tile_iter), + tBsB(_, _, _, write_stage)); + } + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + // Nothing extra to do. + } else if constexpr (ModeHasScales) { + // scale copy + auto tSgS = get<0>(extra_input_partitions); + auto tSsS = get<1>(extra_input_partitions); + + // Temporary factor which will determine which k tile to reload from gmem. Needed so we don't modify tma + // transaction bytes on the fly. 
We must do a ceiling divide here to correctly handle with chunk_size == K. In + // that case, we don't require that K is a multiple of the threadblock tile K + const int scale_load_k = *k_tile_iter / 1; + // const int scale_load_k = *k_tile_iter / mainloop_params.reload_factor; // This will always be 0 when + // chunk_size == K. + if (cute::elect_one_sync()) { + copy( + mainloop_params.tma_load_scale.with(get<2>(input_tensormaps), *tma_barrier, mcast_mask_s), + tSgS(_, _, _, scale_load_k), + tSsS(_, _, _, write_stage)); + } + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + // Nothing extra to do + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + // zero copy + auto tZgZ = get<2>(extra_input_partitions); + auto tZsZ = get<3>(extra_input_partitions); + if (cute::elect_one_sync()) { + copy( + mainloop_params.tma_load_zero.with(get<3>(input_tensormaps), *tma_barrier, mcast_mask_s), + tZgZ(_, _, _, scale_load_k), + tZsZ(_, _, _, write_stage)); + } + } else { + static_assert( + cutlass::detail::dependent_false, "Conversion mode not handled for TMA copy op."); + } + } else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled for TMA copy op."); + } + ++k_tile_iter; + + // Advance smem_pipe_write + ++smem_pipe_write; + } + } + ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + // Perform a Producer Epilogue to prevent early exit of blocks in a Cluster + CUTLASS_DEVICE void load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) { + int lane_predicate = cute::elect_one_sync(); + + // Issue the epilogue waits + if (lane_predicate) { + // This helps avoid early exit of blocks in Cluster. + // Waits for all stages to either be released (all + // Consumer UNLOCKs), or if the stage was never used + // then it would just be acquired since the phase was + // still inverted from make_producer_start_state. 
+ pipeline.producer_tail(smem_pipe_write); + } + } + ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template + CUTLASS_DEVICE void + mma(MainloopPipeline pipeline, + PipelineState smem_pipe_read, + FrgTensorC& accum, + int k_tile_count, + int thread_idx, + TensorStorage& shared_tensors, + Params const& mainloop_params) { + static_assert(is_rmem::value, "C tensor must be rmem resident."); + static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::rank(SwappedSmemLayoutAtomA{}) == 2, "SwappedSmemLayoutAtomA must be rank 2."); + static_assert(cute::rank(SwappedSmemLayoutAtomB{}) == 2, "SwappedSmemLayoutAtomB must be rank 2."); + static_assert( + !cute::is_void_v, + "SM90 GMMA mainloops must specify a non-void copy atom for smem sourced instructions."); + static_assert( + cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + + // Obtain warp index + int warp_idx = canonical_warp_idx_sync(); + [[maybe_unused]] int warp_group_thread_idx = thread_idx % 128; + + Tensor sA_ = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sA = as_position_independent_swizzle_tensor(sA_); // (BLK_M,BLK_K,PIPE) + + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // + // Define C accumulators and A/B partitioning + // + + // Layout of warp group to thread mapping + + static_assert( + stride<0>(typename TiledMma::BLayout{}) == 0 and + size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup, + "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup"); + + constexpr int MmaWarpGroups = size(TiledMma{}) / NumThreadsPerWarpGroup; + Layout warp_group_thread_layout = make_layout(Int{}, Int{}); + + int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0); + + TiledMma tiled_mma; + auto mma_thread_slice = tiled_mma.get_thread_slice(thread_idx); + Tensor tCsA = mma_thread_slice.partition_A(sA); + auto mma_warpgroup_slice = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx)); + + // Allocate fragments and descriptors + Tensor tCrA_mma = mma_thread_slice.partition_fragment_A(sA(_, _, Int<0>{})); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrA_load = [&] { + if constexpr (not is_layout::value) { + // Make register tensor with MMA layout + return make_fragment_like(tCrA_mma); + } else { + // Make register tensor matching smem layout, converter will take care of de-swizzling + return make_tensor_like(tCsA(_, _, _, Int<0>{})); + } + }(); + Tensor tCsB = mma_warpgroup_slice.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + Tensor tCrB = mma_warpgroup_slice.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) + + // + // Copy Atom A retiling + // + auto smem_tiled_copy_A = make_tiled_copy_A(SwappedSmemCopyAtomA{}, tiled_mma); + auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(warp_group_thread_idx); + + Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA_load); // (CPY,CPY_M,CPY_K) + + // Partition of thread -> shared and thread -> RF + auto partitioned_extra_info = Utils::partition_extra_mma_info(mma_thread_slice, shared_tensors); + auto copy_partitions_extra_info = + 
Utils::retile_extra_mma_info(tiled_mma, partitioned_extra_info, warp_group_thread_idx); + + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // CPY_M + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCrA_copy_view)); // CPY_K + CUTE_STATIC_ASSERT_V(size<1>(tCrA_mma) == size<1>(accum)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + + // + // PIPELINED MAIN LOOP + // + + // We release buffers to producer warps(dma load) with some mmas in flight + PipelineState smem_pipe_release = smem_pipe_read; + + multiply_add fma; + + constexpr int NumMMAsPerChunk = GROUP_SIZE / cute::get<0, 1>(tCsB.shape())(); + constexpr int NumChunksPerTileK = cute::size<1>(sA.shape())() / GROUP_SIZE; + cute::array intermediate_array; + + constexpr int K_BLOCK_MAX = size<2>(tCrA_load); + constexpr int K_WAIT_MAX = cute::min(K_BLOCK_MAX - 1, 7); + static_assert(K_BLOCK_MAX >= 4, "Consider increasing TileShapeK"); + + ConsumerToken barrier_token = {BarrierStatus::WaitAgain}; + // First k tile + { + barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + int read_stage = smem_pipe_read.index(); + + ++smem_pipe_read; + barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + + // copy smem->rmem for A operand + + Utils::copy_tensors_MK( + smem_tiled_copy_A, tCsA, tCrA_copy_view, partitioned_extra_info, copy_partitions_extra_info, 0, read_stage); + if (K_BLOCK_MAX > 1) { + Utils::copy_tensors_MK( + smem_tiled_copy_A, tCsA, tCrA_copy_view, partitioned_extra_info, copy_partitions_extra_info, 1, read_stage); + } + + // src: tCrA_load, dst: tCrA_mma + Utils::convert_A_kblock(tCrA_load, tCrA_mma, 0); + + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int chunk_id = 0; chunk_id < NumChunksPerTileK; ++chunk_id) { + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + + CUTLASS_PRAGMA_UNROLL + for (int mma_id = 0; mma_id < NumMMAsPerChunk; ++mma_id) { + int k_block = chunk_id * NumMMAsPerChunk + mma_id; + + warpgroup_arrive(); + + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA_mma(_, _, k_block), tCrB(_, _, k_block, read_stage), intermediate_array[chunk_id]); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + + warpgroup_commit_batch(); + + if (k_block < K_BLOCK_MAX - 2) { + Utils::copy_tensors_MK( + smem_tiled_copy_A, + tCsA, + tCrA_copy_view, + partitioned_extra_info, + copy_partitions_extra_info, + k_block + 2, + read_stage); + } + if (k_block < K_BLOCK_MAX - 1) { + Utils::convert_A_kblock(tCrA_load, tCrA_mma, k_block + 1); + } + } + } + + CUTLASS_PRAGMA_UNROLL + for (int chunk_id_ = 0; chunk_id_ < NumChunksPerTileK; ++chunk_id_) { + warpgroup_fence_operand(intermediate_array[chunk_id_]); + + // Apply the group-wise scaling + // tCrS ((4, _2, _2), MMA_M, _1) + // accum ((2, _2, _2), MMA_M, _1) + auto tCrS = cute::get<1>(partitioned_extra_info); + for (int mma_m = 0; mma_m < size<1>(accum); mma_m++) { + for (int m = 0; m < size<0, 1>(accum); m++) { + for (int n = 0; n < size<0, 2>(accum); n++) { + for (int e = 0; e < size<0, 0>(accum); e++) { + auto accum_coord = make_coord(make_tuple(e, m, n), mma_m, 0); + auto scale_coord = make_coord(make_tuple(0, m, 0), mma_m, 0); + + if (chunk_id_ == 0) { + 
accum(accum_coord) = + intermediate_array[chunk_id_](accum_coord) * static_cast(tCrS(scale_coord)[0]); + } else { + accum(accum_coord) = + fma(intermediate_array[chunk_id_](accum_coord), + static_cast(tCrS(scale_coord)[chunk_id_]), + accum(accum_coord)); + } + } + } + } + } + } + + --k_tile_count; + if (k_tile_count > 0) { + // Wait for K_BLOCK_MAX - 1 to be in flight to ensure that it is safe to overwrite the A registers for the first + // mma. + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + Utils::copy_tensors_MK( + smem_tiled_copy_A, + tCsA, + tCrA_copy_view, + partitioned_extra_info, + copy_partitions_extra_info, + 0, + smem_pipe_read.index()); + + Utils::copy_tensors_MK( + smem_tiled_copy_A, + tCsA, + tCrA_copy_view, + partitioned_extra_info, + copy_partitions_extra_info, + 1, + smem_pipe_read.index()); + + warpgroup_wait(); + Utils::convert_A_kblock(tCrA_load, tCrA_mma, 0); + } + } + + if (k_tile_count == 0) { + return; + } + + // Mainloop GMMAs + CUTLASS_PRAGMA_NO_UNROLL + for (; k_tile_count > 1; --k_tile_count) { + // + // Compute on k_tile + // + + int read_stage = smem_pipe_read.index(); + ++smem_pipe_read; + + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int chunk_id = 0; chunk_id < NumChunksPerTileK; ++chunk_id) { + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + + CUTLASS_PRAGMA_UNROLL + for (int mma_id = 0; mma_id < NumMMAsPerChunk; ++mma_id) { + int k_block = chunk_id * NumMMAsPerChunk + mma_id; + + warpgroup_arrive(); + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA_mma(_, _, k_block), tCrB(_, _, k_block, read_stage), intermediate_array[chunk_id]); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + warpgroup_commit_batch(); + + warpgroup_wait(); // We have K_BLOCK_MAX - 1 GMMA instructions pending for this stage, so we can + // release prior barrier + if (k_block == K_BLOCK_MAX - 1) { + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + ++smem_pipe_release; + } + + if (k_block == 0) { + barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + } + + if (k_block == K_BLOCK_MAX - 1) { + // The last k_block + + CUTLASS_PRAGMA_UNROLL + for (int chunk_id_ = 0; chunk_id_ < NumChunksPerTileK; ++chunk_id_) { + warpgroup_fence_operand(intermediate_array[chunk_id_]); + + // Apply the group-wise scaling + auto tCrS = cute::get<1>(partitioned_extra_info); + for (int mma_m = 0; mma_m < size<1>(accum); mma_m++) { + for (int m = 0; m < size<0, 1>(accum); m++) { + for (int n = 0; n < size<0, 2>(accum); n++) { + for (int e = 0; e < size<0, 0>(accum); e++) { + auto accum_coord = make_coord(make_tuple(e, m, n), mma_m, 0); + auto scale_coord = make_coord(make_tuple(0, m, 0), mma_m, 0); + + accum(accum_coord) = + fma(intermediate_array[chunk_id_](accum_coord), + static_cast(tCrS(scale_coord)[chunk_id_]), + accum(accum_coord)); + } + } + } + } + } + + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + // copy scales when passing k_block=0 + Utils::copy_tensors_MK( + smem_tiled_copy_A, + tCsA, + tCrA_copy_view, + partitioned_extra_info, + copy_partitions_extra_info, + 0, + smem_pipe_read.index()); + Utils::copy_tensors_MK( + smem_tiled_copy_A, + tCsA, + tCrA_copy_view, + partitioned_extra_info, + copy_partitions_extra_info, + 1, + smem_pipe_read.index()); + Utils::convert_A_kblock(tCrA_load, tCrA_mma, 0); + } else { + if (k_block < K_BLOCK_MAX - 2) { + Utils::copy_tensors_MK( + smem_tiled_copy_A, + tCsA, + tCrA_copy_view, + partitioned_extra_info, + copy_partitions_extra_info, + 
k_block + 2, + read_stage); + } + Utils::convert_A_kblock(tCrA_load, tCrA_mma, k_block + 1); + } + } + } + } + + { + // + // Last k tile + // + Tensor intermediate = make_fragment_like(accum); + + int read_stage = smem_pipe_read.index(); + + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) { + warpgroup_arrive(); + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA_mma(_, _, k_block), tCrB(_, _, k_block, read_stage), intermediate); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + warpgroup_commit_batch(); + + warpgroup_wait(); + if (k_block == K_BLOCK_MAX - 1) { + // release prior barrier + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + ++smem_pipe_release; + } + + if (k_block < K_BLOCK_MAX - 2) { + Utils::copy_tensors_MK( + smem_tiled_copy_A, + tCsA, + tCrA_copy_view, + partitioned_extra_info, + copy_partitions_extra_info, + k_block + 2, + read_stage); + } + if (k_block < K_BLOCK_MAX - 1) { + Utils::convert_A_kblock(tCrA_load, tCrA_mma, k_block + 1); + } + + if ((k_block + 1) % NumMMAsPerChunk == 0) { + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + warpgroup_fence_operand(intermediate); + + // Apply the group-wise scaling + auto tCrS = cute::get<1>(partitioned_extra_info); + for (int mma_m = 0; mma_m < size<1>(accum); mma_m++) { + for (int m = 0; m < size<0, 1>(accum); m++) { + for (int n = 0; n < size<0, 2>(accum); n++) { + for (int e = 0; e < size<0, 0>(accum); e++) { + auto accum_coord = make_coord(make_tuple(e, m, n), mma_m, 0); + auto scale_coord = make_coord(make_tuple(0, m, 0), mma_m, 0); + int scale_idx = k_block / NumMMAsPerChunk; + + accum(accum_coord) = fma( + intermediate(accum_coord), static_cast(tCrS(scale_coord)[scale_idx]), accum(accum_coord)); + } + } + } + } + } + } + } + } + ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + /// Perform a Consumer Epilogue to release all buffers + CUTLASS_DEVICE void mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) { + // Prologue GMMAs + int prologue_mma_count = 1; + k_tile_count -= prologue_mma_count; + + smem_pipe_release.advance(k_tile_count); + + // Wait on all GMMAs to complete + warpgroup_wait<0>(); + + for (int count = 0; count < prologue_mma_count; ++count) { + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + ++smem_pipe_release; + } + } + ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + // + // Methods to perform different parts of TMA/Tensormap modifications + // + CUTLASS_DEVICE auto tensormaps_init( + Params const& mainloop_params, TensorMapStorage& shared_tensormaps, int32_t sm_count, int32_t sm_idx) { + cute::TmaDescriptor* gmem_tensormap = reinterpret_cast(mainloop_params.tensormaps); + + cute::TmaDescriptor* tma_desc_a = &gmem_tensormap[sm_idx]; + cute::TmaDescriptor* tma_desc_b = &gmem_tensormap[sm_idx + sm_count]; + cute::TmaDescriptor* tma_desc_scale = &gmem_tensormap[sm_idx + 2 * sm_count]; + cute::TmaDescriptor* tma_desc_zero = &gmem_tensormap[sm_idx + 3 * sm_count]; + + // Bringing tensormaps from params to smem for modification later + Tensor pA_tensormap = make_tensor(mainloop_params.tma_load_a.get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sA_tensormap = 
make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_A), Int<1>{}, Int<1>{}); + Tensor pB_tensormap = make_tensor(mainloop_params.tma_load_b.get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sB_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_B), Int<1>{}, Int<1>{}); + + if (cute::elect_one_sync()) { + copy(recast(pA_tensormap), recast(sA_tensormap)); + copy(recast(pB_tensormap), recast(sB_tensormap)); + } + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + Tensor pS_tensormap = make_tensor(mainloop_params.tma_load_scale.get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sS_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_scale), Int<1>{}, Int<1>{}); + if (cute::elect_one_sync()) { + copy(recast(pS_tensormap), recast(sS_tensormap)); + } + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + Tensor pZ_tensormap = make_tensor(mainloop_params.tma_load_zero.get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sZ_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_zero), Int<1>{}, Int<1>{}); + if (cute::elect_one_sync()) { + copy(recast(pZ_tensormap), recast(sZ_tensormap)); + } + } else if constexpr (KernelConversionMode != ConversionMode::DirectConvert) { + static_assert( + cutlass::detail::dependent_false, "Conversion mode not handled in tensormaps_init."); + } + + __syncwarp(); + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + return cute::make_tuple(tma_desc_a, tma_desc_b); + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return cute::make_tuple(tma_desc_a, tma_desc_b, tma_desc_scale); + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + return cute::make_tuple(tma_desc_a, tma_desc_b, tma_desc_scale, tma_desc_zero); + } else { + static_assert( + cutlass::detail::dependent_false, "Conversion mode not handled in tensormaps_init."); + } + } + + // Replace address for the global tensor (to be done by single thread) + CUTLASS_DEVICE + void tensormaps_replace_global_address( + TensorMapStorage& shared_tensormaps, Params const& mainloop_params, int32_t next_batch) { + // Replacing global_address for the next batch + cute::tma_descriptor_replace_addr_in_shared_mem( + shared_tensormaps.smem_tensormap_A, mainloop_params.ptr_A[next_batch]); + cute::tma_descriptor_replace_addr_in_shared_mem( + shared_tensormaps.smem_tensormap_B, mainloop_params.ptr_B[next_batch]); + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + cute::tma_descriptor_replace_addr_in_shared_mem( + shared_tensormaps.smem_tensormap_scale, mainloop_params.ptr_S[next_batch]); + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + cute::tma_descriptor_replace_addr_in_shared_mem( + shared_tensormaps.smem_tensormap_zero, mainloop_params.ptr_Z[next_batch]); + } else if constexpr (KernelConversionMode != ConversionMode::DirectConvert) { + static_assert( + cutlass::detail::dependent_false, + "Conversion mode not handled in tensormaps_replace_global_address."); + } + } + + // Replace dim and strides for the global tensor - used only for Grouped GEMM (to be done by single thread) + template + CUTLASS_DEVICE void tensormaps_replace_global_tensor_properties( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + int32_t next_group, + ProblemShape_MNKL problem_shape_mnkl) { + const uint32_t M = get<0>(problem_shape_mnkl); + const 
uint32_t N = get<1>(problem_shape_mnkl); + const uint32_t K = get<2>(problem_shape_mnkl); + + // Replace all dims for consistency + constexpr int MaxTensorRank = 5; + cute::array prob_shape_A = {1, 1, 1, 1, 1}; + cute::array prob_stride_A = {0, 0, 0, 0, 0}; + cute::array prob_shape_B = {1, 1, 1, 1, 1}; + cute::array prob_stride_B = {0, 0, 0, 0, 0}; + cute::array prob_shape_scale = {1, 1, 1, 1, 1}; + cute::array prob_stride_scale = {0, 0, 0, 0, 0}; + cute::array prob_shape_zero = {1, 1, 1, 1, 1}; + cute::array prob_stride_zero = {0, 0, 0, 0, 0}; + + SwappedElementA const* ptr_A = nullptr; + Tensor tensor_a = + make_tensor(ptr_A, detail::get_gmem_layout(make_shape(M, K, Int<1>{}), mainloop_params.ptr_dA[next_group])); + + SwappedElementB const* ptr_B = nullptr; + Tensor tensor_b = + make_tensor(ptr_B, detail::get_gmem_layout(make_shape(N, K, Int<1>{}), mainloop_params.ptr_dB[next_group])); + + cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_a, tensor_a, prob_shape_A, prob_stride_A); + cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_b, tensor_b, prob_shape_B, prob_stride_B); + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + NonVoidElementScale const* ptr_S = nullptr; + // auto scale_k = K / mainloop_params.chunk_size; + auto scale_k = K / GROUP_SIZE; + Tensor tensor_scale = + make_tensor(detail::get_logical_ptr(ptr_S), make_shape(M, scale_k, Int<1>{}), mainloop_params.dS[next_group]); + cute::detail::fill_tma_gmem_shape_stride( + mainloop_params.tma_load_scale, tensor_scale, prob_shape_scale, prob_stride_scale); + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + ElementZero const* ptr_Z = nullptr; + // auto scale_k = K / mainloop_params.chunk_size; + auto scale_k = K / GROUP_SIZE; + Tensor tensor_zero = + make_tensor(detail::get_logical_ptr(ptr_Z), make_shape(M, scale_k, Int<1>{}), mainloop_params.dS[next_group]); + cute::detail::fill_tma_gmem_shape_stride( + mainloop_params.tma_load_zero, tensor_zero, prob_shape_zero, prob_stride_zero); + } else if constexpr (KernelConversionMode != ConversionMode::DirectConvert) { + static_assert( + cutlass::detail::dependent_false, + "Conversion mode not handled in tensormaps_replace_global_tensor_properties."); + } + + // Convert strides to byte strides + for (uint64_t& stride : prob_stride_A) { + stride = (stride * sizeof_bits_v) / 8; + } + for (uint64_t& stride : prob_stride_B) { + stride = (stride * sizeof_bits_v) / 8; + } + for (uint64_t& stride : prob_stride_scale) { + stride = (stride * sizeof_bits_v) / 8; + } + for (uint64_t& stride : prob_stride_zero) { + stride = (stride * sizeof_bits_v) / 8; + } + + cute::tma_descriptor_replace_dims_strides_in_shared_mem( + shared_tensormaps.smem_tensormap_A, prob_shape_A, prob_stride_A); + cute::tma_descriptor_replace_dims_strides_in_shared_mem( + shared_tensormaps.smem_tensormap_B, prob_shape_B, prob_stride_B); + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + cute::tma_descriptor_replace_dims_strides_in_shared_mem( + shared_tensormaps.smem_tensormap_scale, prob_shape_scale, prob_stride_scale); + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + cute::tma_descriptor_replace_dims_strides_in_shared_mem( + shared_tensormaps.smem_tensormap_zero, prob_shape_zero, prob_stride_zero); + } else if constexpr (KernelConversionMode != ConversionMode::DirectConvert) { + static_assert( + cutlass::detail::dependent_false, + "Conversion mode not handled 
in tensormaps_replace_global_tensor_properties."); + } + } + + template + CUTLASS_DEVICE void tensormaps_perform_update( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + cute::tuple const& input_tensormaps, + ProblemShape_MNKL problem_shape_mnkl, + int32_t next_batch) { + if (cute::elect_one_sync()) { + // Replacing global_address for the next batch + tensormaps_replace_global_address(shared_tensormaps, mainloop_params, next_batch); + + if constexpr (IsGroupedGemmKernel) { + // Replacing global dims and strides for the next batch + tensormaps_replace_global_tensor_properties(shared_tensormaps, mainloop_params, next_batch, problem_shape_mnkl); + } + } + } + + template + CUTLASS_DEVICE void + tensormaps_cp_fence_release(TensorMapStorage& shared_tensormaps, cute::tuple const& input_tensormaps) { + // Entire warp must do this (i.e. it's aligned) + tma_descriptor_cp_fence_release(get<0>(input_tensormaps), shared_tensormaps.smem_tensormap_A); + tma_descriptor_cp_fence_release(get<1>(input_tensormaps), shared_tensormaps.smem_tensormap_B); + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + tma_descriptor_cp_fence_release(get<2>(input_tensormaps), shared_tensormaps.smem_tensormap_scale); + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + tma_descriptor_cp_fence_release(get<3>(input_tensormaps), shared_tensormaps.smem_tensormap_zero); + } else if constexpr (KernelConversionMode != ConversionMode::DirectConvert) { + static_assert( + cutlass::detail::dependent_false, + "Conversion mode not handled in tensormaps_cp_fence_release."); + } + } + + // The entire warp must call this function collectively (that is, the instructions are aligned) + template + CUTLASS_DEVICE void tensormaps_fence_acquire(cute::tuple const& input_tensormaps) { + cute::tma_descriptor_fence_acquire(get<0>(input_tensormaps)); + cute::tma_descriptor_fence_acquire(get<1>(input_tensormaps)); + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + cute::tma_descriptor_fence_acquire(get<2>(input_tensormaps)); + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + cute::tma_descriptor_fence_acquire(get<3>(input_tensormaps)); + } else if constexpr (KernelConversionMode != ConversionMode::DirectConvert) { + static_assert( + cutlass::detail::dependent_false, "Conversion mode not handled in tensormaps_fence_acquire."); + } + } + + template + CUTLASS_DEVICE InputTensors tensors_perform_update( + InputTensors const& input_tensors, + [[maybe_unused]] Params const& mainloop_params, + [[maybe_unused]] ProblemShape_MNKL problem_shape_mnkl, + [[maybe_unused]] int32_t next_batch) { + return input_tensors; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/csrc/moe/cutlass_moe/w4a8/scaled_mm_entry.cu b/sgl-kernel/csrc/moe/cutlass_moe/w4a8/scaled_mm_entry.cu new file mode 100644 index 000000000..28ed28c1e --- /dev/null +++ b/sgl-kernel/csrc/moe/cutlass_moe/w4a8/scaled_mm_entry.cu @@ -0,0 +1,91 @@ +#include +#include +#include + +int32_t get_sm_version_num() { + int32_t major_capability, minor_capability; + cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor, 0); + cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor, 0); + int32_t 
version_num = major_capability * 10 + minor_capability; + return version_num; +} + +void cutlass_w4a8_moe_mm_sm90( + torch::Tensor& d_tensors, + torch::Tensor const& a_tensors, + torch::Tensor const& b_tensors, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& expert_offsets, + torch::Tensor const& problem_sizes, + torch::Tensor const& a_strides, + torch::Tensor const& b_strides, + torch::Tensor const& d_strides, + torch::Tensor const& s_strides, + int64_t chunk_size, + int64_t topk); + +void get_cutlass_w4a8_moe_mm_data_caller( + const torch::Tensor& topk_ids, + torch::Tensor& expert_offsets, + torch::Tensor& problem_sizes1, + torch::Tensor& problem_sizes2, + torch::Tensor& input_permutation, + torch::Tensor& output_permutation, + const int64_t num_experts, + const int64_t n, + const int64_t k); + +void cutlass_w4a8_moe_mm( + torch::Tensor& d_tensors, + torch::Tensor const& a_tensors, + torch::Tensor const& b_tensors, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& expert_offsets, + torch::Tensor const& problem_sizes, + torch::Tensor const& a_strides, + torch::Tensor const& b_strides, + torch::Tensor const& d_strides, + torch::Tensor const& s_strides, + int64_t chunk_size, + int64_t topk) { + cutlass_w4a8_moe_mm_sm90( + d_tensors, + a_tensors, + b_tensors, + a_scales, + b_scales, + expert_offsets, + problem_sizes, + a_strides, + b_strides, + d_strides, + s_strides, + chunk_size, + topk); + return; +} + +void get_cutlass_w4a8_moe_mm_data( + const torch::Tensor& topk_ids, + torch::Tensor& expert_offsets, + torch::Tensor& problem_sizes1, + torch::Tensor& problem_sizes2, + torch::Tensor& input_permutation, + torch::Tensor& output_permutation, + const int64_t num_experts, + const int64_t n, + const int64_t k) { + get_cutlass_w4a8_moe_mm_data_caller( + topk_ids, + expert_offsets, + problem_sizes1, + problem_sizes2, + input_permutation, + output_permutation, + num_experts, + n, + k); + return; +} diff --git a/sgl-kernel/csrc/moe/cutlass_moe/w4a8/w4a8_get_group_starts.cuh b/sgl-kernel/csrc/moe/cutlass_moe/w4a8/w4a8_get_group_starts.cuh new file mode 100644 index 000000000..f926202c0 --- /dev/null +++ b/sgl-kernel/csrc/moe/cutlass_moe/w4a8/w4a8_get_group_starts.cuh @@ -0,0 +1,92 @@ +#pragma once + +#include +#include +#include + +#include "cutlass/bfloat16.h" +#include "cutlass/float8.h" + +template +__global__ void int4_fp8_get_group_gemm_starts( + int32_t* expert_offsets, + ElementA** a_offsets, + ElementB** b_offsets, + ElementC** out_offsets, + ElementAccumulator** a_scales_offsets, + cutlass::bfloat16_t** b_scales_offsets, + ElementA* a_base_as_int, + ElementB* b_base_as_int, + ElementC* out_base_as_int, + ElementAccumulator* a_scales_base_as_int, + cutlass::bfloat16_t* b_scales_base_as_int, + int64_t n, + int64_t k, + bool per_act_token, + bool per_out_ch) { + int expert_id = threadIdx.x; + int32_t expert_offset = expert_offsets[expert_id]; + + a_offsets[expert_id] = a_base_as_int + expert_offset * k; + b_offsets[expert_id] = b_base_as_int + expert_id * k * n / 2; + out_offsets[expert_id] = out_base_as_int + expert_offset * n; + a_scales_offsets[expert_id] = a_scales_base_as_int + (per_act_token ? expert_offset : 0); + b_scales_offsets[expert_id] = b_scales_base_as_int + (per_out_ch ? 
expert_id * n * 4 * k / 512 : expert_id); +} + +#define __CALL_W4A8_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE) \ + else if (out_tensors.dtype() == TENSOR_C_TYPE) { \ + int4_fp8_get_group_gemm_starts \ + <<<1, num_experts, 0, stream>>>( \ + static_cast(expert_offsets.data_ptr()), \ + static_cast(a_ptrs.data_ptr()), \ + static_cast(b_ptrs.data_ptr()), \ + static_cast(out_ptrs.data_ptr()), \ + static_cast(a_scales_ptrs.data_ptr()), \ + static_cast(b_scales_ptrs.data_ptr()), \ + static_cast(a_tensors.data_ptr()), \ + static_cast(b_tensors.data_ptr()), \ + static_cast(out_tensors.data_ptr()), \ + static_cast(a_scales.data_ptr()), \ + static_cast(b_scales.data_ptr()), \ + out_tensors.size(1), \ + a_tensors.size(1), \ + per_act_token, \ + per_out_ch); \ + } + +namespace { + +void run_int4_fp8_get_group_gemm_starts( + torch::Tensor const& expert_offsets, + torch::Tensor& a_ptrs, + torch::Tensor& b_ptrs, + torch::Tensor& out_ptrs, + torch::Tensor& a_scales_ptrs, + torch::Tensor& b_scales_ptrs, + torch::Tensor const& a_tensors, + torch::Tensor const& b_tensors, + torch::Tensor& out_tensors, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { + TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn); + TORCH_CHECK(b_tensors.dtype() == torch::kInt8); + TORCH_CHECK(a_scales.dtype() == torch::kFloat32); + TORCH_CHECK(b_scales.dtype() == torch::kBFloat16); + + int num_experts = static_cast(expert_offsets.size(0)); + bool per_act_token = a_scales.numel() != 1; + bool per_out_ch = b_scales.numel() != num_experts; + + auto stream = at::cuda::getCurrentCUDAStream(expert_offsets.device().index()); + + if (false) { + } + __CALL_W4A8_GET_STARTS_KERNEL(torch::kBFloat16, cutlass::bfloat16_t) + __CALL_W4A8_GET_STARTS_KERNEL(torch::kFloat16, half) + else { + TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)"); + } +} + +} // namespace diff --git a/sgl-kernel/csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cu b/sgl-kernel/csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cu new file mode 100644 index 000000000..cffa171cc --- /dev/null +++ b/sgl-kernel/csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cu @@ -0,0 +1,240 @@ +#include +#include +#include + +#include "cutlass/cutlass.h" +#include "w4a8_grouped_mm_c3x.cuh" + +using namespace cute; + +namespace { + +#define JOIN_STRUCT_NAME(m, n, k, a, b, c) sm90_fp8_config##_##m##_##n##_##k##_##a##_##b##_##c + +#define JOIN_STRUCT_NAME_CO(m, n, k, a, b, c) sm90_fp8_co_config##_##m##_##n##_##k##_##a##_##b##_##c + +#define GENERATE_SM90_W4A8_PP_CONFIG(M, N, K, A, B, C) \ + struct JOIN_STRUCT_NAME(M, N, K, A, B, C) { \ + using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; \ + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; \ + using TileShape = cute::Shape, cute::Int, cute::Int>; \ + using ClusterShape = cute::Shape, cute::Int, cute::Int>; \ + \ + using Cutlass3xW4A8Gemm = cutlass_3x_w4a8_group_gemm; \ + }; + +#define GENERATE_SM90_W4A8_CO_CONFIG(M, N, K, A, B, C) \ + struct JOIN_STRUCT_NAME_CO(M, N, K, A, B, C) { \ + using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; \ + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; \ + using TileShape = cute::Shape, cute::Int, cute::Int>; \ + using ClusterShape = cute::Shape, cute::Int, cute::Int>; \ + \ + using Cutlass3xW4A8Gemm = cutlass_3x_w4a8_group_gemm; \ + }; + +GENERATE_SM90_W4A8_PP_CONFIG(64, 16, 512, 1, 1, 1) +GENERATE_SM90_W4A8_PP_CONFIG(64, 32, 512, 2, 1, 1) + 
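+// The pingpong configurations above are picked by dispatch_w4a8_moe_mm_sm90 for the smallest
+// M buckets; the cooperative configurations below cover the larger M ranges and the fallback
+// case. A hypothetical additional tile shape (not referenced by the dispatcher) would follow
+// the same pattern, e.g. GENERATE_SM90_W4A8_CO_CONFIG(128, 128, 512, 2, 1, 1), selected via
+// JOIN_STRUCT_NAME_CO(128, 128, 512, 2, 1, 1)::Cutlass3xW4A8Gemm.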
+GENERATE_SM90_W4A8_CO_CONFIG(128, 16, 512, 1, 1, 1) +GENERATE_SM90_W4A8_CO_CONFIG(128, 16, 512, 2, 1, 1) +GENERATE_SM90_W4A8_CO_CONFIG(128, 32, 512, 1, 1, 1) +GENERATE_SM90_W4A8_CO_CONFIG(128, 32, 512, 2, 1, 1) +GENERATE_SM90_W4A8_CO_CONFIG(128, 64, 512, 1, 1, 1) + +void dispatch_w4a8_moe_mm_sm90( + torch::Tensor& d_tensors, + torch::Tensor const& a_tensors, + torch::Tensor const& b_tensors, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& expert_offsets, + torch::Tensor const& problem_sizes, + torch::Tensor const& a_strides, + torch::Tensor const& b_strides, + torch::Tensor const& d_strides, + torch::Tensor const& s_strides, + int64_t chunk_size, + int64_t topk) { + using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; + + uint32_t const m = a_tensors.size(0) / topk; + uint32_t const n = d_tensors.size(1); + uint32_t const k = a_tensors.size(1); + + if (n == 4096 && k == 7168) { + // group gemm 1 + if (m <= 4) { + using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME(64, 32, 512, 2, 1, 1)::Cutlass3xW4A8Gemm; + cutlass_w4a8_group_gemm_caller( + d_tensors, + a_tensors, + b_tensors, + a_scales, + b_scales, + expert_offsets, + problem_sizes, + a_strides, + b_strides, + d_strides, + s_strides, + chunk_size); + } else if (m <= 16) { + using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 16, 512, 2, 1, 1)::Cutlass3xW4A8Gemm; + cutlass_w4a8_group_gemm_caller( + d_tensors, + a_tensors, + b_tensors, + a_scales, + b_scales, + expert_offsets, + problem_sizes, + a_strides, + b_strides, + d_strides, + s_strides, + chunk_size); + } else if (m <= 256) { + using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 16, 512, 1, 1, 1)::Cutlass3xW4A8Gemm; + cutlass_w4a8_group_gemm_caller( + d_tensors, + a_tensors, + b_tensors, + a_scales, + b_scales, + expert_offsets, + problem_sizes, + a_strides, + b_strides, + d_strides, + s_strides, + chunk_size); + } else if (m <= 1024) { + using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 32, 512, 2, 1, 1)::Cutlass3xW4A8Gemm; + cutlass_w4a8_group_gemm_caller( + d_tensors, + a_tensors, + b_tensors, + a_scales, + b_scales, + expert_offsets, + problem_sizes, + a_strides, + b_strides, + d_strides, + s_strides, + chunk_size); + } else { + using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 64, 512, 1, 1, 1)::Cutlass3xW4A8Gemm; + cutlass_w4a8_group_gemm_caller( + d_tensors, + a_tensors, + b_tensors, + a_scales, + b_scales, + expert_offsets, + problem_sizes, + a_strides, + b_strides, + d_strides, + s_strides, + chunk_size); + } + } else if (n == 7168 && k == 2048) { + // group gemm 2 + if (m <= 8) { + using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME(64, 16, 512, 1, 1, 1)::Cutlass3xW4A8Gemm; + cutlass_w4a8_group_gemm_caller( + d_tensors, + a_tensors, + b_tensors, + a_scales, + b_scales, + expert_offsets, + problem_sizes, + a_strides, + b_strides, + d_strides, + s_strides, + chunk_size); + } else if (m <= 512) { + using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 32, 512, 1, 1, 1)::Cutlass3xW4A8Gemm; + cutlass_w4a8_group_gemm_caller( + d_tensors, + a_tensors, + b_tensors, + a_scales, + b_scales, + expert_offsets, + problem_sizes, + a_strides, + b_strides, + d_strides, + s_strides, + chunk_size); + } else { + using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 64, 512, 1, 1, 1)::Cutlass3xW4A8Gemm; + 
cutlass_w4a8_group_gemm_caller( + d_tensors, + a_tensors, + b_tensors, + a_scales, + b_scales, + expert_offsets, + problem_sizes, + a_strides, + b_strides, + d_strides, + s_strides, + chunk_size); + } + } else { + using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 32, 512, 1, 1, 1)::Cutlass3xW4A8Gemm; + cutlass_w4a8_group_gemm_caller( + d_tensors, + a_tensors, + b_tensors, + a_scales, + b_scales, + expert_offsets, + problem_sizes, + a_strides, + b_strides, + d_strides, + s_strides, + chunk_size); + } +} + +} // namespace + +void cutlass_w4a8_moe_mm_sm90( + torch::Tensor& d_tensors, + torch::Tensor const& a_tensors, + torch::Tensor const& b_tensors, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& expert_offsets, + torch::Tensor const& problem_sizes, + torch::Tensor const& a_strides, + torch::Tensor const& b_strides, + torch::Tensor const& d_strides, + torch::Tensor const& s_strides, + int64_t chunk_size, + int64_t topk) { + dispatch_w4a8_moe_mm_sm90( + d_tensors, + a_tensors, + b_tensors, + a_scales, + b_scales, + expert_offsets, + problem_sizes, + a_strides, + b_strides, + d_strides, + s_strides, + chunk_size, + topk); +} diff --git a/sgl-kernel/csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cuh b/sgl-kernel/csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cuh new file mode 100644 index 000000000..1252b245f --- /dev/null +++ b/sgl-kernel/csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cuh @@ -0,0 +1,276 @@ +#pragma once + +/** + * @file w4a8_grouped_mm_c3x.cuh + * @brief Implementation of grouped GEMM operation with int4 and fp8 mixed + * precision + * + * This file implements a grouped GEMM operation that multiplies FP8 matrices + * (A) with quantized INT4 matrices (B), applying per-block scaling factors. + * The implementation is optimized for NVIDIA Hopper GPUs, leveraging Tensor + * Cores for mixed precision arithmetic. 
+ * + * Key features: + * - Supports grouped GEMM operations with multiple experts + * - Uses FP8 (e4m3) for matrix A + * - Uses INT4 quantization for matrix B with per-block scaling + * - Implements preprocessing for INT4 encoding and scale packing + * - Optimized for Hopper architecture with Tensor Core operations + */ + +#include +#include +#include +#include + +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass_extensions/gemm/collective/collective_builder_mixed_input.hpp" +#include "w4a8_get_group_starts.cuh" + +using namespace cute; + +namespace { + +// Type definitions +using MmaType = cutlass::float_e4m3_t; // FP8 e4m3 type +using QuantType = cutlass::int4b_t; // 4-bit integer type +using ElementAccumulator = float; // Accumulator type +using ElementScale = cutlass::bfloat16_t; // Scale type +using ElementScalePacked = cutlass::Array; +using ElementC = cutlass::half_t; // Default output type (FP16) +using ElementD = ElementC; // Default output type (FP16) +using ProblemShape = cutlass::gemm::GroupProblemShape>; + +// Architecture-specific configurations +using ArchTag = cutlass::arch::Sm90; +using OperatorClass = cutlass::arch::OpClassTensorOp; +// constexpr int TileShapeK = 512; +// using TileShape = Shape<_128, _32, cute::Int>; +// using ClusterShape = Shape<_1, _1, _1>; + +// Layout configurations +using LayoutA = cutlass::layout::RowMajor; +using LayoutB = cutlass::layout::ColumnMajor; +using LayoutC = cutlass::layout::RowMajor; +using LayoutD = LayoutC; + +// Transposed layouts +using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose::type; +using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose::type; +using LayoutC_Transpose = typename cutlass::layout::LayoutTranspose::type; +using LayoutD_Transpose = typename cutlass::layout::LayoutTranspose::type; + +// Alignments +static constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; +static constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; +static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; +static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; + +template +struct cutlass_3x_w4a8_group_gemm { + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OperatorClass, + TileShape, + ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, + ElementAccumulator, + ElementC, + LayoutC_Transpose*, + AlignmentC, + ElementD, + LayoutD_Transpose*, + AlignmentD, + EpilogueSchedule>::CollectiveOp; + + using CollectiveMainloopScaleOnly = typename cutlass::gemm::collective::CollectiveBuilderMixedInput< + ArchTag, + OperatorClass, + cute::tuple, + LayoutB_Transpose*, + AlignmentB, + MmaType, + LayoutA_Transpose*, + AlignmentA, + ElementAccumulator, + TileShape, + ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule>::CollectiveOp; + + // Define the final kernel and GEMM operation types + using GemmKernelScaleOnly = + cutlass::gemm::kernel::GemmUniversal; + + using GemmScaleOnly = cutlass::gemm::device::GemmUniversalAdapter; + + using StrideA = cute::remove_pointer_t>; + using StrideB = cute::remove_pointer_t>; + using StrideC = typename 
GemmKernelScaleOnly::InternalStrideC; + using StrideD = typename GemmKernelScaleOnly::InternalStrideD; + using StrideS = typename CollectiveMainloopScaleOnly::StrideScale; +}; + +/** + * @brief Main function to run int4 * fp8 grouped GEMM from PyTorch + * + * This function performs multiple GEMM operations in parallel where each + * operation multiplies an FP8 matrix (A) with a quantized INT4 matrix (B), + * applying per-block scaling factors. It's designed for efficient execution + * on NVIDIA Hopper GPUs, leveraging Tensor Cores for optimal performance with + * mixed precision arithmetic. + * + * The function includes preprocessing steps for both INT4 tensors and scale + * factors to ensure optimal performance and correct operation. + * + * @param d_tensors Output tensor D with shape [total_m, total_n] + * @param a_tensors Tensor containing all A matrices (fp8_e4m3) with shape + * [total_m, K] + * @param b_tensors Tensor containing all B matrices (int4 packed as int8) with + * shape [E, N, K/2] + * @param a_scales Tensor containing A matrix scale factors + * @param b_scales Tensor containing B matrix scale factors with shape [E, + * K//512, N*4] + * @param expert_offsets Tensor containing expert offsets for determining group + * boundaries (int32) + * @param problem_sizes Tensor containing problem sizes with shape [num_experts, + * 3] (N, M, K for each group) (int32) + * @param a_strides Stride information for A tensors + * @param b_strides Stride information for B tensors + * @param d_strides Stride information for D tensors + * @param s_strides Stride information for scale tensors + * @param chunk_size Size of each chunk for scales (K / number of scale chunks) + */ +// template +template +void cutlass_w4a8_group_gemm_caller( + torch::Tensor& d_tensors, + torch::Tensor const& a_tensors, + torch::Tensor const& b_tensors, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& expert_offsets, + torch::Tensor const& problem_sizes, + torch::Tensor const& a_strides, + torch::Tensor const& b_strides, + torch::Tensor const& d_strides, + torch::Tensor const& s_strides, + int64_t chunk_size) { + // using Gemm = cutlass_3x_w4a8_group_gemm; + using Args = typename Gemm::GemmScaleOnly::Arguments; + + int num_experts = static_cast(expert_offsets.size(0)); + bool per_act_token = a_scales.numel() != 1; + bool per_out_ch = b_scales.numel() != num_experts; + + // Check inputs + TORCH_CHECK(a_tensors.dim() == 2, "A tensor must be 2D"); + TORCH_CHECK(b_tensors.dim() == 3, "B tensor must be 3D [E, N, K/2]"); + TORCH_CHECK(b_scales.dim() == 3, "Scale tensor must be 3D [E, K//512, N*4]"); + TORCH_CHECK(a_scales.dim() == 1, "A Scale tensor must be 1D [1]"); + TORCH_CHECK(expert_offsets.dim() == 1, "expert_offsets must be a 1D tensor"); + TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor"); + + // Check tensor shapes + TORCH_CHECK(problem_sizes.size(0) == num_experts, "problem_sizes must have num_experts rows"); + TORCH_CHECK(problem_sizes.size(1) == 3, "problem_sizes must have 3 columns (N, M, K)"); + TORCH_CHECK(b_tensors.size(0) == num_experts, "B tensor first dimension must match number of groups"); + TORCH_CHECK(b_scales.size(0) == num_experts, "Scale tensor first dimension must match number of groups"); + TORCH_CHECK(b_tensors.size(2) * 2 == a_tensors.size(1), "B tensor K/2 dimension must match A tensor K dimension"); + TORCH_CHECK(b_scales.size(1) == a_tensors.size(1) / 512, "Scale tensor second dimension must be K//512"); + 
TORCH_CHECK(b_scales.size(2) == 4 * b_tensors.size(1), "Scale tensor last dimension must be 4*N"); + + // Check tensor types + TORCH_CHECK(a_tensors.scalar_type() == torch::kFloat8_e4m3fn, "A tensor must be fp8 (float_e4m3_t) type"); + TORCH_CHECK(b_tensors.scalar_type() == torch::kInt8, "B tensor must contain packed int4 values (stored as int8)"); + TORCH_CHECK(expert_offsets.scalar_type() == torch::kInt32, "Expert offsets must be int32 type"); + TORCH_CHECK(problem_sizes.scalar_type() == torch::kInt32, "Problem sizes must be int32 type"); + + auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index()); + auto options_int = torch::TensorOptions().dtype(torch::kInt64).device(a_tensors.device()); + + torch::Tensor a_ptrs = torch::empty(num_experts, options_int); + torch::Tensor b_ptrs = torch::empty(num_experts, options_int); + torch::Tensor out_ptrs = torch::empty(num_experts, options_int); + torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int); + torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int); + + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = a_tensors.device().index(); + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + Args arguments; + decltype(arguments.epilogue.thread) fusion_args; + fusion_args.alpha = 1.0f; + fusion_args.beta = 0; + fusion_args.alpha_ptr = a_scales.data_ptr(); + fusion_args.beta_ptr = nullptr; + fusion_args.alpha_ptr_array = nullptr; + fusion_args.beta_ptr_array = nullptr; + fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 0}; + fusion_args.dBeta = {cute::_0{}, cute::_0{}, 0}; + + ProblemShape::UnderlyingProblemShape* problem_sizes_as_shapes = + static_cast(problem_sizes.data_ptr()); + + run_int4_fp8_get_group_gemm_starts( + expert_offsets, + a_ptrs, + b_ptrs, + out_ptrs, + a_scales_ptrs, + b_scales_ptrs, + a_tensors, + b_tensors, + d_tensors, + a_scales, + b_scales); + + arguments = Args{ + cutlass::gemm::GemmUniversalMode::kGrouped, + {num_experts, problem_sizes_as_shapes, nullptr}, + {static_cast(b_ptrs.data_ptr()), + static_cast(b_strides.data_ptr()), + static_cast(a_ptrs.data_ptr()), + static_cast(a_strides.data_ptr()), + static_cast(b_scales_ptrs.data_ptr()), + static_cast(s_strides.data_ptr()), + static_cast(chunk_size)}, + {fusion_args, + nullptr, + nullptr, + static_cast(out_ptrs.data_ptr()), + static_cast(d_strides.data_ptr())}, + hw_info}; + + // Instantiate and run GEMM + typename Gemm::GemmScaleOnly gemm; + size_t workspace_size = Gemm::GemmScaleOnly::get_workspace_size(arguments); + auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a_tensors.device()); + auto workspace = torch::empty(workspace_size, workspace_options); + + cutlass::Status status = gemm.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + TORCH_CHECK(false, "GEMM implementation not supported"); + } + + status = gemm.initialize(arguments, workspace.data_ptr(), stream); + if (status != cutlass::Status::kSuccess) { + TORCH_CHECK(false, "GEMM initialization failed"); + } + + status = gemm.run(stream); + if (status != cutlass::Status::kSuccess) { + TORCH_CHECK(false, "GEMM execution failed"); + } +} + +} // namespace diff --git a/sgl-kernel/csrc/moe/cutlass_moe/w4a8/w4a8_moe_data.cu b/sgl-kernel/csrc/moe/cutlass_moe/w4a8/w4a8_moe_data.cu new file mode 100644 index 000000000..f30f7b025 --- /dev/null +++ b/sgl-kernel/csrc/moe/cutlass_moe/w4a8/w4a8_moe_data.cu @@ -0,0 +1,79 @@ +#include +#include +#include + +#include + 
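+// Note: for each expert e with m_e routed tokens, compute_problem_sizes_w4a8 below
+// fills problem_sizes1[e] = {2 * n, m_e, k} and problem_sizes2[e] = {k, m_e, n}
+// (the grouped GEMM consumes sizes in (N, M, K) order), and
+// compute_expert_offsets_w4a8 prefix-sums the per-expert token counts into
+// expert_offsets.
+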
+constexpr uint64_t THREADS_PER_EXPERT = 512; + +__global__ void compute_problem_sizes_w4a8( + const int32_t* __restrict__ topk_ids, + int32_t* problem_sizes1, + int32_t* problem_sizes2, + int32_t* atomic_buffer, + const int topk_length, + const int n, + const int k) { + int expert_id = blockIdx.x; + + int occurrences = 0; + for (int i = threadIdx.x; i < topk_length; i += THREADS_PER_EXPERT) { + occurrences += (topk_ids[i] == expert_id); + } + atomicAdd(&atomic_buffer[expert_id], occurrences); + __syncthreads(); + + if (threadIdx.x == 0) { + int final_occurrences = atomic_buffer[expert_id]; + problem_sizes1[expert_id * 3] = 2 * n; + problem_sizes1[expert_id * 3 + 1] = final_occurrences; + problem_sizes1[expert_id * 3 + 2] = k; + problem_sizes2[expert_id * 3] = k; + problem_sizes2[expert_id * 3 + 1] = final_occurrences; + problem_sizes2[expert_id * 3 + 2] = n; + } +} + +__global__ void compute_expert_offsets_w4a8( + const int32_t* __restrict__ problem_sizes1, + int32_t* expert_offsets, + int32_t* atomic_buffer, + const int num_experts) { + int32_t tot_offset = 0; + expert_offsets[0] = 0; + for (int i = 0; i < num_experts; ++i) { + atomic_buffer[i] = tot_offset; + tot_offset += problem_sizes1[i * 3 + 1]; + expert_offsets[i + 1] = tot_offset; + } +} + +void get_cutlass_w4a8_moe_mm_data_caller( + const torch::Tensor& topk_ids, + torch::Tensor& expert_offsets, + torch::Tensor& problem_sizes1, + torch::Tensor& problem_sizes2, + torch::Tensor& input_permutation, + torch::Tensor& output_permutation, + const int64_t num_experts, + const int64_t n, + const int64_t k) { + auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index()); + auto options_int32 = torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device()); + torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32); + + int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel()); + compute_problem_sizes_w4a8<<>>( + static_cast(topk_ids.data_ptr()), + static_cast(problem_sizes1.data_ptr()), + static_cast(problem_sizes2.data_ptr()), + static_cast(atomic_buffer.data_ptr()), + topk_ids.numel(), + n, + k); + compute_expert_offsets_w4a8<<<1, 1, 0, stream>>>( + static_cast(problem_sizes1.data_ptr()), + static_cast(expert_offsets.data_ptr()), + static_cast(atomic_buffer.data_ptr()), + num_experts); +} diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h index c53ecdc01..6811fdd55 100644 --- a/sgl-kernel/include/sgl_kernel_ops.h +++ b/sgl-kernel/include/sgl_kernel_ops.h @@ -467,6 +467,35 @@ void transfer_kv_all_layer_mla_direct( int64_t page_size, int64_t num_layers); +/* + * From csrc/moe/cutlass_moe/w4a8 + */ +void get_cutlass_w4a8_moe_mm_data( + const torch::Tensor& topk_ids, + torch::Tensor& expert_offsets, + torch::Tensor& problem_sizes1, + torch::Tensor& problem_sizes2, + torch::Tensor& input_permutation, + torch::Tensor& output_permutation, + const int64_t num_experts, + const int64_t n, + const int64_t k); + +void cutlass_w4a8_moe_mm( + torch::Tensor& d_tensors, + torch::Tensor const& a_tensors, + torch::Tensor const& b_tensors, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& expert_offsets, + torch::Tensor const& problem_sizes, + torch::Tensor const& a_strides, + torch::Tensor const& b_strides, + torch::Tensor const& d_strides, + torch::Tensor const& s_strides, + int64_t chunk_size, + int64_t topk); + /* * From FlashInfer */ diff --git a/sgl-kernel/python/sgl_kernel/__init__.py b/sgl-kernel/python/sgl_kernel/__init__.py index 
65a037d54..27666f0f6 100755 --- a/sgl-kernel/python/sgl_kernel/__init__.py +++ b/sgl-kernel/python/sgl_kernel/__init__.py @@ -19,6 +19,7 @@ from sgl_kernel.attention import ( merge_state, merge_state_v2, ) +from sgl_kernel.cutlass_moe import cutlass_w4a8_moe_mm, get_cutlass_w4a8_moe_mm_data from sgl_kernel.elementwise import ( apply_rope_with_cos_sin_cache_inplace, fused_add_rmsnorm, diff --git a/sgl-kernel/python/sgl_kernel/cutlass_moe.py b/sgl-kernel/python/sgl_kernel/cutlass_moe.py new file mode 100644 index 000000000..52256bda1 --- /dev/null +++ b/sgl-kernel/python/sgl_kernel/cutlass_moe.py @@ -0,0 +1,112 @@ +import torch + + +def get_cutlass_w4a8_moe_mm_data( + topk_ids: torch.Tensor, + expert_offsets: torch.Tensor, + problem_sizes1: torch.Tensor, + problem_sizes2: torch.Tensor, + input_permutation: torch.Tensor, + output_permutation: torch.Tensor, + num_experts: int, + n: int, + k: int, +): + """ + Prepare data necessary to perform CUTLASS grouped matrix multiplications + used in CUTLASS-based fused MoE. + + The function takes in topk_ids (token-expert mapping) and uses it to + compute: + - expert_offsets: Indices that mark at which token index each expert begins + its computation after the input is sorted with + input_permutation. The number of tokens computed with + expert E is expert_offsets[E + 1] - expert_offsets[E] + - problem_sizes1, problem_sizes2: MxNxK sizes of each expert's + multiplication in two grouped MMs used in + the fused MoE operation (each row is stored + in (N, M, K) order). + - input_permutation: Permutation that must be used to shuffle the input + before executing the MMs. + - output_permutation: Permutation that must be used to shuffle the output + after executing the MMs. + """ + torch.ops.sgl_kernel.get_cutlass_w4a8_moe_mm_data.default( + topk_ids, + expert_offsets, + problem_sizes1, + problem_sizes2, + input_permutation, + output_permutation, + num_experts, + n, + k, + ) + + +def cutlass_w4a8_moe_mm( + d: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + a_scales: torch.Tensor, + b_scales: torch.Tensor, + experts_offsets: torch.Tensor, + problem_sizes: torch.Tensor, + a_strides: torch.Tensor, + b_strides: torch.Tensor, + d_strides: torch.Tensor, + s_strides: torch.Tensor, + chunk_size: int = 128, + topk: int = 8, +): + """ + Perform grouped matrix multiplication between int4 weights and fp8 activations. + + This function executes multiple GEMM operations in parallel, which is useful for + scenarios like Mixture of Experts (MoE) where different inputs go through different + experts. The implementation leverages NVIDIA Hopper architecture features for + optimal performance with quantized weights.
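+
+    A minimal call sketch (variable setup here is illustrative only; see
+    sgl-kernel/tests/test_cutlass_w4a8_moe_mm.py in this patch for a complete,
+    runnable example):
+
+        d = torch.empty((total_m, total_n), dtype=torch.float16, device="cuda")
+        cutlass_w4a8_moe_mm(d, a_q, b_packed, a_scales, b_scales,
+                            expert_offsets, problem_sizes,
+                            a_strides, b_strides, d_strides, s_strides,
+                            chunk_size=128, topk=8)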
+ + Args: + d: Output matrices of shape [total_m, total_n] + a: Activation matrices in FP8 (float_e4m3_t) format + Each tensor should be of shape [total_m, K] in row-major layout + b: Weight matrices in packed int4 format + Each tensor should be of shape [E, N, K//2] in column-major layout + where each byte contains two 4-bit integers + a_scales: Scale factors for the inputs + b_scales: Scale factors for the quantized weights + Each tensor should be of shape [E, K//512, N*4] + experts_offsets: Tensor containing expert offsets for determining group boundaries + problem_sizes: Tensor of shape [num_experts, 3] holding (N, M, K) for each group (int32) + a_strides: Strides information for A matrices + b_strides: Strides information for B matrices + d_strides: Strides information for D matrices + s_strides: Strides information for b_scales matrices + chunk_size: Number of K elements each scale value applies to, defaults to 128 + topk: Number of experts selected per token, defaults to 8 + + Requirements: + - All tensors must be on a CUDA device + - Requires an NVIDIA Hopper GPU (H100) + - A tensors must be in float8_e4m3fn format + - B tensors must contain packed int4 values (stored as int8) + + Note: + The function computes: D = (A * (B * scales)) + for each group of tensors in parallel + """ + + torch.ops.sgl_kernel.cutlass_w4a8_moe_mm.default( + d, + a, + b, + a_scales, + b_scales, + experts_offsets, + problem_sizes, + a_strides, + b_strides, + d_strides, + s_strides, + chunk_size, + topk, + ) diff --git a/sgl-kernel/tests/test_cutlass_w4a8_moe_mm.py b/sgl-kernel/tests/test_cutlass_w4a8_moe_mm.py new file mode 100644 index 000000000..3cdd62edd --- /dev/null +++ b/sgl-kernel/tests/test_cutlass_w4a8_moe_mm.py @@ -0,0 +1,260 @@ +import pytest +import torch +from sgl_kernel import cutlass_w4a8_moe_mm + + +def pack_int4_values_to_int8(int4_values_interleaved: torch.Tensor) -> torch.Tensor: + if int4_values_interleaved.shape[-1] % 2 != 0: + raise ValueError( + "the last dim size of int4_values_interleaved tensor must be even."
+ ) + + input_tensor_int8 = int4_values_interleaved.to(torch.int8) + + low_nibbles = input_tensor_int8[..., 0::2] + high_nibbles = input_tensor_int8[..., 1::2] + + packed_tensor = (high_nibbles << 4) | (low_nibbles & 0x0F) + + return packed_tensor.to(torch.int8) + + +def pack_interleave(num_experts, ref_weight, ref_scale): + n, k = ref_weight.shape[1], ref_weight.shape[2] + + weight = pack_int4_values_to_int8(ref_weight.cpu()).cuda() + w_q = weight.view((num_experts, n, k // 2)).view(torch.int8) + w_q = w_q.contiguous() + + scale_interleaved = ref_scale.reshape( + ref_scale.shape[0], ref_scale.shape[1], (ref_scale.shape[2] // 4), 4 + ) # [E, N, K/4, 4] + scale_interleaved = scale_interleaved.permute(0, 2, 1, 3) # [E, K/4, N, 4] + scale_interleaved = scale_interleaved.reshape( + ref_scale.shape[0], ref_scale.shape[2] // 4, ref_scale.shape[1] * 4 + ) # [E, K/4, N*4] + w_scale = scale_interleaved.contiguous() + + return w_q, w_scale + + +@pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16]) +def test_int4_fp8_grouped_gemm_single_expert(batch_size): + # Test parameters + num_experts = 1 + m = batch_size # batch size + k = 512 # input dimension + n = 1024 # output dimension + torch.manual_seed(0) + dtype = torch.bfloat16 + device = "cuda" + debug = False + + print(f"\nTesting with batch_size={batch_size}") + + # Create input tensors with ones + if debug: + a = torch.ones(m, k, dtype=torch.bfloat16, device=device) + ref_w = torch.ones(num_experts, n, k, dtype=torch.int8, device=device) + a_scale = torch.ones(1, dtype=torch.float, device=device) + ref_w_scale = torch.ones(num_experts, n, k // 128, dtype=dtype, device=device) + else: + a = torch.randn(m, k, dtype=dtype, device=device) + ref_w = torch.randint( + -8, 8, (num_experts, n, k), dtype=torch.int8, device=device + ) + affine_coeff = 0.005 + a_scale = torch.randn(1, dtype=torch.float32).cuda() * 0.02 + ref_w_scale = ( + torch.randn(num_experts, n, k // 128, dtype=dtype, device=device) + * affine_coeff + ) + + w, w_scale = pack_interleave(num_experts, ref_w, ref_w_scale) + + # Create expert offsets and problem sizes + expert_offsets = torch.tensor([0, m], dtype=torch.int32, device=device) + problem_sizes = torch.tensor([[n, m, k]], dtype=torch.int32, device=device) + + a_strides = torch.full((num_experts, 3), k, device=device, dtype=torch.int64) + c_strides = torch.full((num_experts, 3), n, device=device, dtype=torch.int64) + b_strides = a_strides + s_strides = c_strides + + # Quantize input + a_q = torch.clamp((a / a_scale), -448.0, 448.0).to(torch.float8_e4m3fn).to(device) + + # Create output tensor + c = torch.empty((m, n), dtype=torch.float16, device=device) + cutlass_w4a8_moe_mm( + c, + a_q, + w, + a_scale, + w_scale, + expert_offsets[:-1], + problem_sizes, + a_strides, + b_strides, + c_strides, + s_strides, + 128, + 8, + ) + c = c.to(dtype) + + # Reference implementation + experts_selection_result = torch.full((m,), 0) + c_ref = ref_grouped_gemm( + c, a, a_scale, ref_w, ref_w_scale, num_experts, experts_selection_result + ) + + # Compare results + try: + torch.testing.assert_close(c, c_ref, rtol=1e-2, atol=0.1) + except AssertionError as e: + # torch.set_printoptions(threshold=10_000) + print(f" FAILURE: tensors are NOT close.") + print(f" Ref tensor: {c_ref.flatten()}") + print(f" Cutlass tensor: {c.flatten()}") + print( + f" Max absolute difference: {torch.max(torch.abs(c.to(c_ref.dtype) - c_ref))}" + ) + print( + f" Mean absolute difference: {torch.mean(torch.abs(c.to(c_ref.dtype) - c_ref))}" + ) + print(f" AssertionError: {e}") 
+ raise + + +@pytest.mark.parametrize("batch_size", [2, 4, 8, 16]) +@pytest.mark.parametrize("k", [512, 1024]) +@pytest.mark.parametrize("n", [1024, 2048]) +@pytest.mark.parametrize("num_experts", [2, 4, 6, 8]) +def test_int4_fp8_grouped_gemm_multi_experts(batch_size, k, n, num_experts): + torch.manual_seed(0) + dtype = torch.bfloat16 + device = "cuda" + debug = False + + print( + f"\nTesting with batch_size={batch_size}, k={k}, n={n}, num_experts={num_experts}" + ) + + if debug: + a = torch.ones(batch_size, k, dtype=torch.bfloat16, device=device) + ref_w = torch.ones(num_experts, n, k, dtype=torch.int8, device=device) + a_scale = torch.ones(1, dtype=torch.float, device=device) + ref_w_scale = torch.ones(num_experts, n, k // 128, dtype=dtype, device=device) + else: + a = torch.randn(batch_size, k, dtype=dtype, device=device) + ref_w = torch.randint( + -8, 8, (num_experts, n, k), dtype=torch.int8, device=device + ) + affine_coeff = 0.005 + a_scale = torch.randn(1, dtype=torch.float32).cuda() * 0.02 + ref_w_scale = ( + torch.randn(num_experts, n, k // 128, dtype=dtype, device=device) + * affine_coeff + ) + + w, w_scale = pack_interleave(num_experts, ref_w, ref_w_scale) + + # random select experts + experts_selection_result = torch.randint( + 0, num_experts, (batch_size,), device=device + ) + permutation = torch.argsort(experts_selection_result) + expert_token_counts = torch.bincount( + experts_selection_result, minlength=num_experts + ) + + # Create problem sizes and offsets for active experts + problem_sizes = [] + for i in range(num_experts): + problem_sizes.append([n, expert_token_counts[i].item(), k]) + problem_sizes = torch.tensor(problem_sizes, dtype=torch.int32, device=device) + + expert_offsets = [] + offset = 0 + for i in range(num_experts): + expert_offsets.append(offset) + offset += problem_sizes[i][1].item() + expert_offsets = torch.tensor(expert_offsets, dtype=torch.int32, device=device) + + # Permute input and quantize + a_perm = a[permutation] + a_q_perm = ( + torch.clamp((a_perm / a_scale), -448.0, 448.0) + .to(torch.float8_e4m3fn) + .to(device) + ) + + # Create stride tensors + a_strides = torch.full((num_experts, 3), k, device=device, dtype=torch.int64) + c_strides = torch.full((num_experts, 3), n, device=device, dtype=torch.int64) + b_strides = a_strides + s_strides = c_strides + + c_perm = torch.empty((batch_size, n), dtype=torch.float16, device=device) + cutlass_w4a8_moe_mm( + c_perm, + a_q_perm, + w, + a_scale, + w_scale, + expert_offsets, + problem_sizes, + a_strides, + b_strides, + c_strides, + s_strides, + 128, + 8, + ) + + # Un-permute the result + c = torch.empty_like(c_perm) + c[permutation] = c_perm + c = c.to(dtype) + + c_ref = ref_grouped_gemm( + c, a, a_scale, ref_w, ref_w_scale, num_experts, experts_selection_result + ) + + # Compare results + try: + torch.testing.assert_close(c, c_ref, rtol=1e-2, atol=0.1) + except AssertionError as e: + print(f" FAILURE: tensors are NOT close.") + print( + f" Max absolute difference: {torch.max(torch.abs(c.to(c_ref.dtype) - c_ref))}" + ) + print( + f" Mean absolute difference: {torch.mean(torch.abs(c.to(c_ref.dtype) - c_ref))}" + ) + print(f" AssertionError: {e}") + raise + + +def ref_grouped_gemm(c, a, a_scale, w, w_scale, num_experts, experts_selection_result): + dtype = torch.bfloat16 + c_ref = torch.zeros_like(c) + a_q = torch.clamp((a / a_scale), -448.0, 448.0).to(torch.float8_e4m3fn) + for i in range(num_experts): + token_idx = torch.where(experts_selection_result == i)[0] + if len(token_idx) == 0: + continue + a = 
a_q[token_idx] + + ref_w_scale_repeat = w_scale[i].repeat_interleave(128, dim=1).to(float) + ref_w = (w[i].to(float) * ref_w_scale_repeat).to(dtype) + c = torch.matmul(a.to(dtype), ref_w.t().to(dtype)) * a_scale + c = c.to(dtype) + c_ref[token_idx] = c.to(dtype) + + return c_ref + + +if __name__ == "__main__": + pytest.main([__file__])
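
Usage note (sketch): the tests above drive cutlass_w4a8_moe_mm with hand-built expert offsets and problem sizes; the companion data-preparation op is not exercised here. Below is a minimal sketch of how get_cutlass_w4a8_moe_mm_data can be called. Buffer shapes are inferred from the kernels in w4a8_moe_data.cu and the sizes are illustrative; note that the caller in this patch launches only the problem-size and offset kernels, so input_permutation/output_permutation are allocated but not populated by the op.

    import torch
    from sgl_kernel import get_cutlass_w4a8_moe_mm_data

    # Illustrative sizes (assumptions, not part of the kernel contract).
    num_experts, num_tokens, topk, n, k = 4, 16, 2, 1024, 512
    device = "cuda"

    # Flattened token -> expert routing ids; the CUDA kernel reads them as int32.
    topk_ids = torch.randint(
        0, num_experts, (num_tokens * topk,), dtype=torch.int32, device=device
    )

    # Output buffers (the permutation buffers are required by the op signature,
    # but are not written by the caller shown in this patch).
    expert_offsets = torch.empty(num_experts + 1, dtype=torch.int32, device=device)
    problem_sizes1 = torch.empty((num_experts, 3), dtype=torch.int32, device=device)
    problem_sizes2 = torch.empty((num_experts, 3), dtype=torch.int32, device=device)
    input_permutation = torch.empty(num_tokens * topk, dtype=torch.int32, device=device)
    output_permutation = torch.empty(num_tokens * topk, dtype=torch.int32, device=device)

    get_cutlass_w4a8_moe_mm_data(
        topk_ids, expert_offsets, problem_sizes1, problem_sizes2,
        input_permutation, output_permutation, num_experts, n, k,
    )
    # After the call: problem_sizes1[e] == [2 * n, m_e, k] and
    # problem_sizes2[e] == [k, m_e, n], where m_e is the number of entries of
    # topk_ids equal to e; expert_offsets[1:] is the running sum of m_e.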