[1/n]: add cutlass W4A8 moe kernel for hopper architecture (#7772)

Signed-off-by: yangsijia.614 <yangsijia.614@bytedance.com>
Co-authored-by: yicwang <yichen.wang@bytedance.com>
This commit is contained in:
SijiaYang
2025-07-05 11:50:12 +08:00
committed by GitHub
parent cb432f1770
commit da3890e82a
16 changed files with 3602 additions and 0 deletions

View File

@@ -277,6 +277,25 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
"int num_layers) -> ()");
m.impl("transfer_kv_all_layer_mla_direct", torch::kCUDA, &transfer_kv_all_layer_mla_direct);
/*
* From csrc/moe/cutlass_moe/w4a8
*/
m.def(
"get_cutlass_w4a8_moe_mm_data(Tensor topk_ids, Tensor! expert_offsets, "
" Tensor! problem_sizes1, Tensor! problem_sizes2, "
" Tensor! input_permutation, "
" Tensor! output_permutation, int num_experts, "
" int n, int k) -> ()");
m.impl("get_cutlass_w4a8_moe_mm_data", torch::kCUDA, &get_cutlass_w4a8_moe_mm_data);
m.def(
"cutlass_w4a8_moe_mm(Tensor! d, Tensor a, Tensor b, "
" Tensor a_scales, Tensor b_scales, Tensor expert_offsets, "
" Tensor problem_sizes, Tensor a_strides, "
" Tensor b_strides, Tensor d_strides, Tensor s_strides,"
" int chunk_size, int topk) -> ()");
m.impl("cutlass_w4a8_moe_mm", torch::kCUDA, &cutlass_w4a8_moe_mm);
/*
* From FlashInfer
*/

View File

@@ -0,0 +1,482 @@
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cute/arch/copy_sm90.hpp"
#include "cute/numeric/arithmetic_tuple.hpp"
#include "cute/util/type_traits.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_conversion.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::collective::detail {
template <class Collective>
struct MixedGroupedGemmInputUtils {
private:
using KernelSchedule = typename Collective::KernelSchedule;
using ConversionMode = typename Collective::ConversionMode;
using SmemLayoutA = typename Collective::SmemLayoutA;
using SmemLayoutB = typename Collective::SmemLayoutB;
using SmemLayoutScale = typename Collective::SmemLayoutScale;
using SwappedElementA = typename Collective::SwappedElementA;
using SwappedElementB = typename Collective::SwappedElementB;
using RealSwappedElementA = typename Collective::RealSwappedElementA;
using RealSwappedElementB = typename Collective::RealSwappedElementB;
using ElementScale = typename Collective::ElementScale;
using ElementZero = typename Collective::ElementZero;
using SmemCopyAtomScale = typename Collective::SmemCopyAtomScale;
static constexpr auto KernelConversionMode = Collective::KernelConversionMode;
static constexpr auto ModeHasScales = Collective::ModeHasScales;
static constexpr auto UseScaleLookupTable = Collective::UseScaleLookupTable;
public:
static constexpr auto elements_per_smem_scale() {
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
return 0;
} else if constexpr (ModeHasScales) {
return cute::cosize_v<SmemLayoutScale>;
} else {
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Type not handled in scale smem allocation.");
}
}
static constexpr auto elements_per_smem_zero() {
if constexpr (
KernelConversionMode == ConversionMode::DirectConvert ||
KernelConversionMode == ConversionMode::ConvertAndScale) {
return 0;
} else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
return cute::cosize_v<SmemLayoutScale>;
} else {
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Type not handled in scale smem allocation.");
}
}
// These methods use some the public members of the class. For that reason, we define them after the public section.
static constexpr uint32_t compute_tma_transaction_bytes_mk() {
return cutlass::bits_to_bytes(
size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast<uint32_t>(cute::sizeof_bits_v<SwappedElementA>));
}
static constexpr uint32_t compute_tma_transaction_bytes_nk() {
return cutlass::bits_to_bytes(
size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast<uint32_t>(cute::sizeof_bits_v<SwappedElementB>));
}
static constexpr uint32_t compute_tma_transaction_bytes_extra() {
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
return 0;
} else if constexpr (ModeHasScales) {
constexpr uint32_t scale_tx_bytes = cutlass::bits_to_bytes(
size<0>(SmemLayoutScale{}) * size<1>(SmemLayoutScale{}) *
static_cast<uint32_t>(cute::sizeof_bits_v<ElementScale>));
static_assert(scale_tx_bytes % 128 == 0, "Each scale stage must be 128B aligned."); // required by TMA
if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
return scale_tx_bytes;
} else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
// Scale and zero share smem layout
constexpr uint32_t zero_tx_bytes = cutlass::bits_to_bytes(
size<0>(SmemLayoutScale{}) * size<1>(SmemLayoutScale{}) *
static_cast<uint32_t>(cute::sizeof_bits_v<ElementZero>));
static_assert(zero_tx_bytes % 128 == 0, "Each zero stage must be 128B aligned."); // required by TMA
return scale_tx_bytes + zero_tx_bytes;
} else {
static_assert(
cutlass::detail::dependent_false<KernelSchedule>, "Type not handled in tma transaction bytes computation.");
}
} else {
static_assert(
cutlass::detail::dependent_false<KernelSchedule>, "Type not handled in tma transaction bytes computation.");
}
}
/// Utilities to copy A and extra inputs from smem to RF
template <class SmemTiledCopyA, class TensorASmemView, class TensorACopyView, class... Ts, class... Us>
CUTLASS_DEVICE static void copy_tensors_MK(
SmemTiledCopyA const& smem_tiled_copy_A,
TensorASmemView const& tCsA,
TensorACopyView& tCrA_copy_view,
cute::tuple<Ts...> const& partitioned_mma_extra_info,
cute::tuple<Us...> const& tiled_copy_and_views,
int k_block,
int read_stage) {
copy(smem_tiled_copy_A, tCsA(_, _, k_block, read_stage), tCrA_copy_view(_, _, k_block));
if (k_block == 0) {
// We are starting a new k-tile so copy the scale
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
// nothing to do
} else if constexpr (ModeHasScales) {
auto smem_tiled_copy_S = cute::get<0>(tiled_copy_and_views);
auto tCrS_copy_view = cute::get<1>(tiled_copy_and_views);
auto tCsS = cute::get<0>(partitioned_mma_extra_info);
copy(smem_tiled_copy_S, tCsS(_, _, k_block, read_stage), tCrS_copy_view(_, _, k_block));
if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
// Nothing extra to do
} else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
auto tCsZ = cute::get<2>(partitioned_mma_extra_info);
auto tCrZ_copy_view = cute::get<2>(tiled_copy_and_views);
copy(smem_tiled_copy_S, tCsZ(_, _, k_block, read_stage), tCrZ_copy_view(_, _, k_block));
} else {
static_assert(
cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in A -> RF path.");
}
} else {
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in A -> RF path.");
}
}
}
// The core converter uses a lookup table to converts i4 -> 8 bit value.
template <
class EngineIn,
class LayoutIn,
class EngineOut,
class LayoutOut,
class EngineScale,
class LayoutScale>
CUTLASS_DEVICE static void lookup_table_convert( // Accept mutable temporaries
Tensor<EngineIn, LayoutIn> const& src,
Tensor<EngineOut, LayoutOut>&& dst,
Tensor<EngineScale, LayoutScale> const& scales_neg,
Tensor<EngineScale, LayoutScale> const& scales_pos) {
lookup_table_convert(src, dst, scales_neg, scales_pos);
}
template <class EngineIn, class LayoutIn, class EngineOut, class LayoutOut, class EngineScale, class LayoutScale>
CUTLASS_DEVICE static void lookup_table_convert(
Tensor<EngineIn, LayoutIn> const& src,
Tensor<EngineOut, LayoutOut>& dst,
Tensor<EngineScale, LayoutScale> const& scales_neg,
Tensor<EngineScale, LayoutScale> const& scales_pos) {
constexpr int N = cute::cosize(LayoutIn{});
static_assert(N == 4 || N == 8);
static_assert(cosize(LayoutScale{}) <= N / 4, "at least 4 consecutive weights must share the same scale.");
using SrcArray = cutlass::Array<cutlass::int4b_t, 8>;
using DstArray = cutlass::Array<RealSwappedElementB, 8>;
using RegArray = cutlass::AlignedArray<uint32_t, N / 4, sizeof(DstArray)>;
// View the input as reg
auto&& src_reg = cute::recast<uint32_t>(src)(0);
auto&& r = cute::recast<RegArray>(dst)(0);
// Determines if to get from the signed or unsigned candidates
static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
uint32_t sign; // ((reg & 0x88888888) | 0x64206420) >> 1
asm volatile(
"{\n"
" lop3.b32 %0, %1, %2, %3, %4;\n"
"}\n"
: "=r"(sign)
: "r"(src_reg), "n"(0x88888888), "n"(0x64206420), "n"(immLut));
sign = sign >> 1;
// Ignore sign bit when indexing into LUT
uint32_t lut_idx = src_reg & 0x77777777;
Tensor scales_neg_ = cute::filter(scales_neg);
Tensor scales_pos_ = cute::filter(scales_pos);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N / 4; ++i, lut_idx >>= 16, sign >>= 16) {
auto&& scale_neg_ = reinterpret_cast<cutlass::Array<uint32_t, 2> const&>(scales_neg_(i));
auto&& scale_pos_ = reinterpret_cast<cutlass::Array<uint32_t, 2> const&>(scales_pos_(i));
asm volatile(
"{\n"
" .reg .b32 pos, neg ;\n"
" prmt .b32 neg, %3, %4, %1 ;\n"
" prmt .b32 pos, %5, %6, %1 ;\n"
" prmt .b32 %0, pos, neg, %2 ;\n"
"}\n"
: "=r"(r[i])
: "r"(lut_idx), "r"(sign), "r"(scale_neg_[0]), "r"(scale_neg_[1]), "r"(scale_pos_[0]), "r"(scale_pos_[1]));
}
}
/// Utilities to dequantize A.
template <class Layout>
CUTLASS_DEVICE static void static_check_scale(Layout const& tensor) {
static_assert(
shape<0>(Layout{}) >= 4 && stride<0>(Layout{}) == 0,
"At least 4 adjacent weights in a thread must share the same scale.");
}
template <class Engine, class Layout>
CUTLASS_DEVICE static void static_check_scale(Tensor<Engine, Layout> const& tensor) {
static_check_scale(flatten(Layout{}));
}
template <class EngineIn, class EngineOut, class LayoutIn, class LayoutOut, class... Ts>
CUTLASS_DEVICE static void dequantize_A_kblock(
Tensor<EngineIn, LayoutIn> const& tCrA_load,
Tensor<EngineOut, LayoutOut>& tCrA_mma,
cute::tuple<Ts...>& partitioned_extra_info,
int const k_block) {
static_assert(is_rmem<EngineIn>::value, "Input tensor for A conversion must come from registers");
static_assert(is_rmem<EngineOut>::value, "Output tensor for A conversion must come from registers");
static_assert(cosize_v<LayoutIn> == cosize_v<LayoutOut>);
static_assert(size_v<LayoutIn> == cosize_v<LayoutIn>);
static_assert(size_v<LayoutOut> == cosize_v<LayoutOut>);
using SrcType = typename EngineIn::value_type;
using DstType = typename EngineOut::value_type;
Tensor src = tCrA_load(_, _, k_block);
Tensor dst = tCrA_mma(_, _, k_block);
CUTE_STATIC_ASSERT_V(
size(src(_, 0)) == cosize(src(_, 0).layout()), "The first mode of tensor src must be contiguous in memory");
// try to make the size of the first mode equal to 32bit
int constexpr NumValPerSrcReg = cute::min(decltype(size(src(_, 0)))::value, ceil_div(32, sizeof_bits_v<SrcType>));
Tensor src_vm = cute::group_modes<1, -1>(cute::zipped_divide(src, Int<NumValPerSrcReg>{}));
Tensor dst_vm = cute::group_modes<1, -1>(cute::zipped_divide(dst, Int<NumValPerSrcReg>{}));
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size<1>(dst_vm); ++i) {
LayoutAwareConvert(src_vm(_, i), dst_vm(_, i));
}
} else if constexpr (UseScaleLookupTable) {
constexpr int num_elements = decltype(size(src))::value;
static_assert(
is_same_v<RealSwappedElementA, cutlass::int4b_t>,
"Lookup table only supports int4 being the quant type now.");
static_assert(sizeof_bits_v<ElementScale> == 64, "Lookup table only supports 8 8bit scale values now.");
static_assert(
num_elements % 4 == 0 && num_elements >= 4, "Lookup table requires a vector size of 4x when converting.");
Tensor tCrS_neg = cute::get<1>(partitioned_extra_info);
auto&& tCrS_pos = cute::get<2>(partitioned_extra_info); // modification to its value is needed
Tensor scales_neg = tCrS_neg(_, _, k_block);
Tensor scales_pos = tCrS_pos(_, _, k_block);
CUTE_STATIC_ASSERT_V(cute::size(src) == cute::size(scales_neg));
static_check_scale(scales_neg);
static_check_scale(scales_pos);
Tensor scales_neg_vm = cute::group_modes<1, -1>(cute::zipped_divide(scales_neg, Int<NumValPerSrcReg>{}));
Tensor scales_pos_vm = cute::group_modes<1, -1>(cute::zipped_divide(scales_pos, Int<NumValPerSrcReg>{}));
if (k_block == 0) {
Tensor scales_neg_vm_ = filter(scales_neg_vm);
Tensor scales_pos_vm_ = filter(scales_pos_vm);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(scales_neg_vm_.layout()); ++i) {
auto&& scale_neg_ = reinterpret_cast<cutlass::Array<uint32_t, 2> const&>(scales_neg_vm_(i));
auto&& scale_pos_ = reinterpret_cast<cutlass::Array<uint32_t, 2>&>(scales_pos_vm_(i));
constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa;
asm volatile(
"{\n"
" lop3 .b32 %0, %2, %4, %5, %6;\n"
" xor .b32 %1, %3, %5; \n"
"}\n"
: "=r"(scale_pos_[0]), "=r"(scale_pos_[1])
: "r"(scale_neg_[0]), "r"(scale_neg_[1]), "n"(0xFFFFFF00), "n"(0x80808080), "n"(immLut));
}
}
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size<1>(dst_vm); ++i) {
lookup_table_convert(src_vm(_, i), dst_vm(_, i), scales_neg_vm(_, i), scales_pos_vm(_, i));
}
} else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
Tensor scales = cute::get<1>(partitioned_extra_info)(_, _, k_block);
CUTE_STATIC_ASSERT_V(size(src) == size(scales));
Tensor scales_vm = cute::group_modes<1, -1>(cute::zipped_divide(scales, Int<NumValPerSrcReg>{}));
if constexpr (is_same_v<DstType, ElementScale>) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size<1>(dst_vm); ++i) {
LayoutAwareConvert(src_vm(_, i), dst_vm(_, i));
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < size<0>(dst_vm); ++j) {
dst_vm(j, i) *= scales_vm(j, i);
}
}
} else {
auto stage = make_tensor_like<ElementScale>(src_vm(_, 0));
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size<1>(dst_vm); ++i) {
LayoutAwareConvert(src_vm(_, i), stage);
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < size<0>(dst_vm); ++j) {
stage(j) *= scales_vm(j, i);
}
LayoutAwareConvert(stage, dst_vm(_, i));
}
}
} else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
static_assert(is_same_v<ElementScale, ElementZero>, "ElementScale and ElementZero must be the same.");
Tensor scales = cute::get<1>(partitioned_extra_info)(_, _, k_block);
Tensor zeros = cute::get<3>(partitioned_extra_info)(_, _, k_block);
CUTE_STATIC_ASSERT_V(size(src) == size(scales));
CUTE_STATIC_ASSERT_V(size(src) == size(zeros));
Tensor scales_vm = cute::group_modes<1, -1>(cute::zipped_divide(scales, Int<NumValPerSrcReg>{}));
Tensor zeros_vm = cute::group_modes<1, -1>(cute::zipped_divide(zeros, Int<NumValPerSrcReg>{}));
if constexpr (is_same_v<DstType, ElementScale>) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size<1>(dst_vm); ++i) {
LayoutAwareConvert(src_vm(_, i), dst_vm(_, i));
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < size<0>(dst_vm); ++j) {
dst_vm(j, i) = dst_vm(j, i) * scales_vm(j, i) + zeros_vm(j, i);
}
}
} else {
auto stage = make_tensor_like<ElementScale>(src_vm(_, 0));
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size<1>(dst_vm); ++i) {
LayoutAwareConvert(src_vm(_, i), stage);
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < size<0>(dst_vm); ++j) {
stage(j) = stage(j) * scales_vm(j, i) + zeros_vm(j, i);
}
LayoutAwareConvert(stage, dst_vm(_, i));
}
}
} else {
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "No A data is loaded.");
}
}
template <class EngineIn, class EngineOut, class LayoutIn, class LayoutOut, class... Ts>
CUTLASS_DEVICE static void convert_A_kblock(
Tensor<EngineIn, LayoutIn> const& tCrA_load, Tensor<EngineOut, LayoutOut>& tCrA_mma, int const k_block) {
static_assert(is_rmem<EngineIn>::value, "Input tensor for A conversion must come from registers");
static_assert(is_rmem<EngineOut>::value, "Output tensor for A conversion must come from registers");
static_assert(cosize_v<LayoutIn> == cosize_v<LayoutOut>);
static_assert(size_v<LayoutIn> == cosize_v<LayoutIn>);
static_assert(size_v<LayoutOut> == cosize_v<LayoutOut>);
using SrcType = typename EngineIn::value_type;
Tensor src = tCrA_load(_, _, k_block);
Tensor dst = tCrA_mma(_, _, k_block);
CUTE_STATIC_ASSERT_V(
size(src(_, 0)) == cosize(src(_, 0).layout()), "The first mode of tensor src must be contiguous in memory");
// try to make the size of the first mode equal to 32bit
int constexpr NumValPerSrcReg = cute::min(decltype(size(src(_, 0)))::value, ceil_div(32, sizeof_bits_v<SrcType>));
Tensor src_vm = cute::group_modes<1, -1>(cute::zipped_divide(src, Int<NumValPerSrcReg>{}));
Tensor dst_vm = cute::group_modes<1, -1>(cute::zipped_divide(dst, Int<NumValPerSrcReg>{}));
// KernelConversionMode == ConversionMode::DirectConvert
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size<1>(dst_vm); ++i) {
LayoutAwareConvert(src_vm(_, i), dst_vm(_, i));
}
}
/// Utilities for any additional inputs inside of the TMA load
template <class Params, class TensorStorage, class... Ts>
CUTLASS_DEVICE static auto partition_extra_tma_inputs(
Params const& mainloop_params,
cute::tuple<Ts...> const& load_inputs,
TensorStorage& shared_tensors,
uint2 const& cluster_local_block_id,
int const m_coord,
int const l_coord) {
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
return cute::make_tuple();
} else if constexpr (ModeHasScales) {
Tensor sS =
make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{}); // (BLK_M,BLK_K,PIPE)
Tensor gS_mkl = get<2>(load_inputs);
auto block_tma_s = mainloop_params.tma_load_scale.get_slice(cluster_local_block_id.y);
Tensor gS = gS_mkl(_, _, m_coord, _, l_coord); // (BLK_M,BLK_K,k)
Tensor tSgS = block_tma_s.partition_S(gS); // (TMA,TMA_M,TMA_K,k)
Tensor tSsS = block_tma_s.partition_D(sS); // (TMA,TMA_M,TMA_K,PIPE)
if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
return cute::make_tuple(tSgS, tSsS);
} else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
Tensor sZ =
make_tensor(make_smem_ptr(shared_tensors.smem_zero.begin()), SmemLayoutScale{}); // (BLK_M,BLK_K,PIPE)
Tensor gZ_mkl = get<3>(load_inputs);
auto block_tma_z = mainloop_params.tma_load_zero.get_slice(cluster_local_block_id.y);
Tensor gZ = gZ_mkl(_, _, m_coord, _, l_coord); // (BLK_M,BLK_K,k)
Tensor tZgZ = block_tma_z.partition_S(gZ); // (TMA,TMA_M,TMA_K,k)
Tensor tZsZ = block_tma_z.partition_D(sZ); // (TMA,TMA_M,TMA_K,PIPE)
return cute::make_tuple(tSgS, tSsS, tZgZ, tZsZ);
} else {
static_assert(
cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled for input partitioning.");
}
} else {
static_assert(
cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled for input partitioning.");
}
}
/// Utilities for partitioning extra inputs for loading from smem in the mainloop.
template <class ThreadMma, class TensorStorage>
CUTLASS_DEVICE static auto
partition_extra_mma_info(ThreadMma const& mma_thread_slice, TensorStorage& shared_tensors) {
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
// nothing to do
return cute::make_tuple();
} else if constexpr (UseScaleLookupTable) {
Tensor sS =
make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{}); // (BLK_M,BLK_SCALE_K,PIPE)
Tensor tCsS = mma_thread_slice.partition_A(sS);
Tensor tCrS = make_tensor<ElementScale>(mma_thread_slice.partition_fragment_A(sS(_, _, Int<0>{})).layout());
return cute::make_tuple(tCsS, tCrS);
} else if constexpr (ModeHasScales) {
Tensor sS =
make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{}); // (BLK_M,BLK_SCALE_K,PIPE)
Tensor tCsS = mma_thread_slice.partition_A(sS);
Tensor tCrS = make_tensor<ElementScale>(mma_thread_slice.partition_fragment_A(sS(_, _, Int<0>{})).layout());
if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
return cute::make_tuple(tCsS, tCrS);
} else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
Tensor sZ = make_tensor(
make_smem_ptr(shared_tensors.smem_zero.begin()), SmemLayoutScale{}); // (BLK_M,BLK_SCALE_K,PIPE)
Tensor tCsZ = mma_thread_slice.partition_A(sZ);
Tensor tCrZ = make_tensor<ElementZero>(mma_thread_slice.partition_fragment_A(sZ(_, _, Int<0>{})).layout());
return cute::make_tuple(tCsS, tCrS, tCsZ, tCrZ);
} else {
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in A -> RF path.");
}
} else {
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in A -> RF path.");
}
}
/// Returns the tiled copy and copy views for the extra inputs.
template <class TiledMma, class... Ts>
CUTLASS_DEVICE static auto retile_extra_mma_info(
TiledMma const& tiled_mma, cute::tuple<Ts...>& partitioned_extra_info, int const warp_group_thread_idx) {
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
// nothing to do
return cute::make_tuple();
} else if constexpr (ModeHasScales) {
auto smem_tiled_copy_S = make_tiled_copy_A(SmemCopyAtomScale{}, tiled_mma);
auto smem_thr_copy_S = smem_tiled_copy_S.get_thread_slice(warp_group_thread_idx);
Tensor tCrS_copy_view = smem_thr_copy_S.retile_D(cute::get<1>(partitioned_extra_info)); // (CPY,CPY_M,CPY_K)
if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
return cute::make_tuple(smem_tiled_copy_S, tCrS_copy_view);
} else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
Tensor tCrZ_copy_view = smem_thr_copy_S.retile_D(cute::get<3>(partitioned_extra_info)); // (CPY,CPY_M,CPY_K)
return cute::make_tuple(smem_tiled_copy_S, tCrS_copy_view, tCrZ_copy_view);
} else {
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in A -> RF path.");
}
} else {
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in A -> RF path.");
}
}
};
} // namespace cutlass::gemm::collective::detail

View File

@@ -0,0 +1,278 @@
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cute/arch/cluster_sm90.hpp"
#include "cute/tensor.hpp"
#include "cutlass/gemm/collective/builders/sm90_common.inl"
#include "cutlass/gemm/collective/collective_builder_decl.hpp"
#include "cutlass/gemm/collective/collective_mma_decl.hpp"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/pipeline/sm90_pipeline.hpp"
// SM90 Collective Builders should be used only starting CUDA 12.0
#if (__CUDACC_VER_MAJOR__ >= 12)
#define CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
#endif
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::collective {
/////////////////////////////////////////////////////////////////////////////////////////////////
// GMMA_TMA_WS_RS
template <
class ElementA_,
class GmemLayoutATag_,
int AlignmentA,
class ElementB_,
class GmemLayoutBTag_,
int AlignmentB,
class ElementAccumulator,
class TileShape_MNK,
class ClusterShape_MNK,
class StageCountType,
class KernelScheduleType>
struct CollectiveBuilderMixedInput<
arch::Sm90,
arch::OpClassTensorOp,
ElementA_,
GmemLayoutATag_,
AlignmentA,
ElementB_,
GmemLayoutBTag_,
AlignmentB,
ElementAccumulator,
TileShape_MNK,
ClusterShape_MNK,
StageCountType,
KernelScheduleType,
cute::enable_if_t<
(cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecialized> ||
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedPingpong> ||
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperative> ||
cute::is_same_v<KernelScheduleType, KernelPtrArrayTmaWarpSpecializedCooperative> ||
cute::is_same_v<KernelScheduleType, KernelPtrArrayTmaWarpSpecializedPingpong>) &&
(detail::is_use_rmem_A<ElementA_, GmemLayoutATag_, ElementB_, GmemLayoutBTag_>() ||
// ConvertAndScale and ConvertAndScaleWithZero
cute::is_tuple<ElementA_>::value || cute::is_tuple<ElementB_>::value ||
// DirectConvert
sizeof_bits<ElementA_>::value != sizeof_bits<ElementB_>::value)>> {
private:
using ScaleA = detail::deduce_mixed_width_dtype_t<1, ElementA_>;
using ScaleB = detail::deduce_mixed_width_dtype_t<1, ElementB_>;
using ZeroA = detail::deduce_mixed_width_dtype_t<2, ElementA_>;
using ZeroB = detail::deduce_mixed_width_dtype_t<2, ElementB_>;
static constexpr bool NeitherIsTuple = !cute::is_tuple<ElementA_>::value && !cute::is_tuple<ElementB_>::value;
// Determine if mixed input types.
static constexpr bool IsMixedInput = cute::sizeof_bits_v<detail::deduce_mixed_width_dtype_t<0, ElementA_>> !=
cute::sizeof_bits_v<detail::deduce_mixed_width_dtype_t<0, ElementB_>>;
static constexpr bool IsArrayOfPointersGemm = cute::is_any_of_v<
KernelScheduleType,
KernelPtrArrayTmaWarpSpecializedCooperative,
KernelPtrArrayTmaWarpSpecializedPingpong>;
static_assert(IsMixedInput || !IsArrayOfPointersGemm, "Only mixed input grouped RS GEMM is supported.");
public:
using ElementA = detail::deduce_mixed_width_dtype_t<0, ElementA_>;
using ElementB = detail::deduce_mixed_width_dtype_t<0, ElementB_>;
static_assert(
!IsMixedInput || (cute::is_tuple<ElementA_>::value ^ cute::is_tuple<ElementB_>::value ||
(NeitherIsTuple && (sizeof_bits<ElementA>::value != sizeof_bits<ElementB>::value))),
"Either A OR B must be a tuple or the widths of A and B must be different.");
static constexpr bool IsANarrow = sizeof_bits<ElementA>::value < sizeof_bits<ElementB>::value;
template <class T>
static auto get_stride(T const& t) {
if constexpr (not cute::is_layout<cute::remove_pointer_t<T>>::value) {
return t;
} else {
if constexpr (cute::is_pointer_v<T>) {
return &cute::stride(*t);
} else {
return cute::stride(t);
}
}
}
using GmemLayoutATag = decltype(get_stride(GmemLayoutATag_{}));
using GmemLayoutBTag = decltype(get_stride(GmemLayoutBTag_{}));
using ElementPairA =
cute::conditional_t<IsMixedInput && IsANarrow && NeitherIsTuple, cute::tuple<ElementA>, ElementA_>;
using ElementPairB =
cute::conditional_t<IsMixedInput && !IsANarrow && NeitherIsTuple, cute::tuple<ElementB>, ElementB_>;
static constexpr bool IsATransformed = cute::is_tuple<ElementPairA>::value;
using ElementScale = cute::conditional_t<IsATransformed, ScaleA, ScaleB>;
using ElementZero = cute::conditional_t<IsATransformed, ZeroA, ZeroB>;
static_assert(is_static<TileShape_MNK>::value);
static_assert(is_static<ClusterShape_MNK>::value);
static_assert(
detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>(),
"Should meet TMA alignment requirement\n");
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
static_assert(cutlass::detail::dependent_false<ElementA>, "Unsupported Toolkit for SM90 Collective Builder\n");
#endif
static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_rs_tag_to_major_A<GmemLayoutATag>();
static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_rs_tag_to_major_B<GmemLayoutBTag>();
// If A is scaled, then we don't need to swap. Otherwise, we must ensure B goes to rmem and we must swap the
// operands.
static constexpr bool SwapAB =
IsMixedInput ? !IsATransformed : detail::is_swapAB<ElementA, GmemLayoutATag, ElementB, GmemLayoutBTag>();
static constexpr bool IsWarpSpecializedTransposeB =
detail::is_warpspecialized_transpose_B<ElementA, GmemLayoutATag, ElementB, GmemLayoutBTag, KernelScheduleType>();
static_assert(!IsMixedInput || !IsWarpSpecializedTransposeB, "Mixed input GEMM does not support WS transpose B.");
// When we relax the above assertion, we must handle setting the tile mma GmmaMajorB correctly.
static constexpr cute::GMMA::Major TiledMmaGmmaMajorB = SwapAB ? GmmaMajorA : GmmaMajorB;
// For fp32 types, map to tf32 MMA value type.
using ElementAMma = cute::conditional_t<cute::is_same_v<ElementA, float>, tfloat32_t, ElementA>;
using ElementBMma = cute::conditional_t<cute::is_same_v<ElementB, float>, tfloat32_t, ElementB>;
// Handle mixed dtypes and MMA.
using RealElementA = cute::conditional_t<SwapAB, ElementBMma, ElementAMma>;
using RealElementB = cute::conditional_t<SwapAB, ElementAMma, ElementBMma>;
using RealElementAMma = cute::conditional_t<IsMixedInput, RealElementB, RealElementA>;
// Always the same for element B.
using RealElementBMma = RealElementB;
static_assert(
!IsMixedInput || TiledMmaGmmaMajorB == GMMA::Major::K || sizeof_bits<RealElementB>::value == 16,
"Mixed input GEMM does not support MN major layout except for 16bit");
using AtomLayoutMNK = cute::conditional_t<
cute::is_any_of_v<
KernelScheduleType,
KernelTmaWarpSpecializedCooperative,
KernelPtrArrayTmaWarpSpecializedCooperative>,
Layout<Shape<_2, _1, _1>>,
Layout<Shape<_1, _1, _1>>>;
using TiledMma = decltype(cute::make_tiled_mma(
cute::GMMA::rs_op_selector<
RealElementAMma,
RealElementBMma,
ElementAccumulator,
TileShape_MNK,
GMMA::Major::K,
GMMA::Major::K>(),
AtomLayoutMNK{}));
using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{})));
using SmemLayoutAtomA = decltype(detail::rs_smem_selector<
GmmaMajorA,
ElementAMma,
decltype(cute::get<0>(TileShape_MNK{})),
decltype(cute::get<2>(TileShape_MNK{})),
IsWarpSpecializedTransposeB>());
using SmemLayoutAtomB = decltype(detail::rs_smem_selector<
GmmaMajorB,
ElementBMma,
decltype(cute::get<1>(TileShape_MNK{})),
decltype(cute::get<2>(TileShape_MNK{})),
IsWarpSpecializedTransposeB>());
static constexpr size_t SmemAlignmentA = cutlass::detail::alignment_for_swizzle(SmemLayoutAtomA{});
static constexpr size_t SmemAlignmentB = cutlass::detail::alignment_for_swizzle(SmemLayoutAtomB{});
static constexpr int SmemAlignment = static_cast<int>(cute::max(SmemAlignmentA, SmemAlignmentB));
// Handle mixed dtype array GEMM's size of tensor map storage.
static constexpr size_t TensorMapStorage = sizeof(cute::TmaDescriptor) * size_t(IsMixedInput) * 4;
static constexpr int KernelSmemCarveout = static_cast<int>(TensorMapStorage);
static constexpr int Sm90ReducedSmemCapacityBytes = detail::sm90_smem_capacity_bytes - KernelSmemCarveout;
static constexpr int PipelineStages =
IsMixedInput ? (IsArrayOfPointersGemm ? detail::compute_stage_count_or_override_single_affine_transformed_input<
Sm90ReducedSmemCapacityBytes,
RealElementA,
RealElementB,
ElementScale,
ElementZero,
TileShape_MNK,
StageCountType::bytes,
SmemAlignment>(StageCountType{})
: detail::compute_stage_count_or_override_single_affine_transformed_input<
detail::sm90_smem_capacity_bytes,
RealElementA,
RealElementB,
ElementScale,
ElementZero,
TileShape_MNK,
StageCountType::bytes,
SmemAlignment>(StageCountType{}))
: detail::compute_stage_count_or_override<
detail::sm90_smem_capacity_bytes,
ElementAMma,
ElementBMma,
TileShape_MNK,
StageCountType::bytes,
SmemAlignment>(StageCountType{});
using DispatchPolicy = cute::conditional_t<
IsMixedInput,
cute::conditional_t<
IsArrayOfPointersGemm,
MainloopSm90ArrayTmaGmmaWarpSpecializedMixedInput<PipelineStages, ClusterShape_MNK, KernelScheduleType>,
MainloopSm90TmaGmmaRmemAWarpSpecializedMixedInput<PipelineStages, ClusterShape_MNK, KernelScheduleType>>,
MainloopSm90TmaGmmaRmemAWarpSpecialized<PipelineStages, ClusterShape_MNK, KernelScheduleType>>;
using SmemCopyAtomA = cute::conditional_t<SwapAB, void, Copy_Atom<cute::AutoVectorizingCopy, ElementA>>;
using SmemCopyAtomB = cute::conditional_t<SwapAB, Copy_Atom<cute::AutoVectorizingCopy, ElementB>, void>;
// We pack the scale data with the operand that will be optionally scaled and converted before MMA.
using StrideA = cute::conditional_t<
cute::is_layout<cute::remove_pointer_t<GmemLayoutATag_>>::value,
GmemLayoutATag_,
TagToStrideA_t<GmemLayoutATag>>;
using StrideB = cute::conditional_t<
cute::is_layout<cute::remove_pointer_t<GmemLayoutBTag_>>::value,
GmemLayoutBTag_,
TagToStrideB_t<GmemLayoutBTag>>;
using CollectiveOp = CollectiveMmaArrayMixedInput<
DispatchPolicy,
TileShape_MNK,
ElementPairA,
StrideA,
ElementPairB,
StrideB,
TiledMma,
GmemTiledCopyA,
SmemLayoutAtomA,
SmemCopyAtomA,
cute::identity,
GmemTiledCopyB,
SmemLayoutAtomB,
SmemCopyAtomB,
cute::identity>;
static_assert(
SmemAlignment == static_cast<int>(cute::max(CollectiveOp::SmemAlignmentA, CollectiveOp::SmemAlignmentB)));
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm::collective
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -0,0 +1,52 @@
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
/////////////////////////////////////////////////////////////////////////////////////////////////
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass_extensions/gemm/collective/collective_mma_array_mixed_input.hpp"
namespace cutlass::gemm::collective {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
class ArchTag,
class OpClass,
class ElementA,
class GmemLayoutA,
int AlignmentA,
class ElementB,
class GmemLayoutB,
int AlignmentB,
class ElementAccumulator,
class TileShape_MNK,
class ClusterShape_MNK,
class StageCountType,
class KernelScheduleType,
class Enable = void>
struct CollectiveBuilderMixedInput {
static_assert(sizeof(ElementA) == 0, "Could not build a collective for given parameters.");
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm::collective
/////////////////////////////////////////////////////////////////////////////////////////////////
#include "cutlass_extensions/gemm/collective/builders/sm90_gmma_builder_mixed_input.inl"
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -0,0 +1,53 @@
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cutlass/detail/dependent_false.hpp"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::collective {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
class DispatchPolicy,
class TileShape,
class ElementA,
class StrideA,
class ElementB,
class StrideB,
class TiledMma,
class GmemTiledCopyA,
class SmemLayoutAtomA,
class SmemCopyAtomA,
class TransformA,
class GmemTiledCopyB,
class SmemLayoutAtomB,
class SmemCopyAtomB,
class TransformB>
struct CollectiveMmaArrayMixedInput {
static_assert(cutlass::detail::dependent_false<ElementA>, "Could not find a mainloop specialization.");
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm::collective
/////////////////////////////////////////////////////////////////////////////////////////////////
#include "cutlass_extensions/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input_.hpp"
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -0,0 +1,91 @@
#include <c10/cuda/CUDAGuard.h>
#include <cudaTypedefs.h>
#include <torch/all.h>
int32_t get_sm_version_num() {
int32_t major_capability, minor_capability;
cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor, 0);
cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor, 0);
int32_t version_num = major_capability * 10 + minor_capability;
return version_num;
}
void cutlass_w4a8_moe_mm_sm90(
torch::Tensor& d_tensors,
torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes,
torch::Tensor const& a_strides,
torch::Tensor const& b_strides,
torch::Tensor const& d_strides,
torch::Tensor const& s_strides,
int64_t chunk_size,
int64_t topk);
void get_cutlass_w4a8_moe_mm_data_caller(
const torch::Tensor& topk_ids,
torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2,
torch::Tensor& input_permutation,
torch::Tensor& output_permutation,
const int64_t num_experts,
const int64_t n,
const int64_t k);
void cutlass_w4a8_moe_mm(
torch::Tensor& d_tensors,
torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes,
torch::Tensor const& a_strides,
torch::Tensor const& b_strides,
torch::Tensor const& d_strides,
torch::Tensor const& s_strides,
int64_t chunk_size,
int64_t topk) {
cutlass_w4a8_moe_mm_sm90(
d_tensors,
a_tensors,
b_tensors,
a_scales,
b_scales,
expert_offsets,
problem_sizes,
a_strides,
b_strides,
d_strides,
s_strides,
chunk_size,
topk);
return;
}
void get_cutlass_w4a8_moe_mm_data(
const torch::Tensor& topk_ids,
torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2,
torch::Tensor& input_permutation,
torch::Tensor& output_permutation,
const int64_t num_experts,
const int64_t n,
const int64_t k) {
get_cutlass_w4a8_moe_mm_data_caller(
topk_ids,
expert_offsets,
problem_sizes1,
problem_sizes2,
input_permutation,
output_permutation,
num_experts,
n,
k);
return;
}

View File

@@ -0,0 +1,92 @@
#pragma once
#include <c10/cuda/CUDAStream.h>
#include <cuda.h>
#include <torch/all.h>
#include "cutlass/bfloat16.h"
#include "cutlass/float8.h"
template <typename ElementA, typename ElementB, typename ElementC, typename ElementAccumulator>
__global__ void int4_fp8_get_group_gemm_starts(
int32_t* expert_offsets,
ElementA** a_offsets,
ElementB** b_offsets,
ElementC** out_offsets,
ElementAccumulator** a_scales_offsets,
cutlass::bfloat16_t** b_scales_offsets,
ElementA* a_base_as_int,
ElementB* b_base_as_int,
ElementC* out_base_as_int,
ElementAccumulator* a_scales_base_as_int,
cutlass::bfloat16_t* b_scales_base_as_int,
int64_t n,
int64_t k,
bool per_act_token,
bool per_out_ch) {
int expert_id = threadIdx.x;
int32_t expert_offset = expert_offsets[expert_id];
a_offsets[expert_id] = a_base_as_int + expert_offset * k;
b_offsets[expert_id] = b_base_as_int + expert_id * k * n / 2;
out_offsets[expert_id] = out_base_as_int + expert_offset * n;
a_scales_offsets[expert_id] = a_scales_base_as_int + (per_act_token ? expert_offset : 0);
b_scales_offsets[expert_id] = b_scales_base_as_int + (per_out_ch ? expert_id * n * 4 * k / 512 : expert_id);
}
#define __CALL_W4A8_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE) \
else if (out_tensors.dtype() == TENSOR_C_TYPE) { \
int4_fp8_get_group_gemm_starts<cutlass::float_e4m3_t, cutlass::int8_t, C_TYPE, float> \
<<<1, num_experts, 0, stream>>>( \
static_cast<int32_t*>(expert_offsets.data_ptr()), \
static_cast<cutlass::float_e4m3_t**>(a_ptrs.data_ptr()), \
static_cast<cutlass::int8_t**>(b_ptrs.data_ptr()), \
static_cast<C_TYPE**>(out_ptrs.data_ptr()), \
static_cast<float**>(a_scales_ptrs.data_ptr()), \
static_cast<cutlass::bfloat16_t**>(b_scales_ptrs.data_ptr()), \
static_cast<cutlass::float_e4m3_t*>(a_tensors.data_ptr()), \
static_cast<cutlass::int8_t*>(b_tensors.data_ptr()), \
static_cast<C_TYPE*>(out_tensors.data_ptr()), \
static_cast<float*>(a_scales.data_ptr()), \
static_cast<cutlass::bfloat16_t*>(b_scales.data_ptr()), \
out_tensors.size(1), \
a_tensors.size(1), \
per_act_token, \
per_out_ch); \
}
namespace {
void run_int4_fp8_get_group_gemm_starts(
torch::Tensor const& expert_offsets,
torch::Tensor& a_ptrs,
torch::Tensor& b_ptrs,
torch::Tensor& out_ptrs,
torch::Tensor& a_scales_ptrs,
torch::Tensor& b_scales_ptrs,
torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors,
torch::Tensor& out_tensors,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(b_tensors.dtype() == torch::kInt8);
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kBFloat16);
int num_experts = static_cast<int>(expert_offsets.size(0));
bool per_act_token = a_scales.numel() != 1;
bool per_out_ch = b_scales.numel() != num_experts;
auto stream = at::cuda::getCurrentCUDAStream(expert_offsets.device().index());
if (false) {
}
__CALL_W4A8_GET_STARTS_KERNEL(torch::kBFloat16, cutlass::bfloat16_t)
__CALL_W4A8_GET_STARTS_KERNEL(torch::kFloat16, half)
else {
TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
}
}
} // namespace

View File

@@ -0,0 +1,240 @@
#include <c10/cuda/CUDAGuard.h>
#include <cudaTypedefs.h>
#include <torch/all.h>
#include "cutlass/cutlass.h"
#include "w4a8_grouped_mm_c3x.cuh"
using namespace cute;
namespace {
#define JOIN_STRUCT_NAME(m, n, k, a, b, c) sm90_fp8_config##_##m##_##n##_##k##_##a##_##b##_##c
#define JOIN_STRUCT_NAME_CO(m, n, k, a, b, c) sm90_fp8_co_config##_##m##_##n##_##k##_##a##_##b##_##c
#define GENERATE_SM90_W4A8_PP_CONFIG(M, N, K, A, B, C) \
struct JOIN_STRUCT_NAME(M, N, K, A, B, C) { \
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; \
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; \
using TileShape = cute::Shape<cute::Int<M>, cute::Int<N>, cute::Int<K>>; \
using ClusterShape = cute::Shape<cute::Int<A>, cute::Int<B>, cute::Int<C>>; \
\
using Cutlass3xW4A8Gemm = cutlass_3x_w4a8_group_gemm<TileShape, ClusterShape, KernelSchedule, EpilogueSchedule>; \
};
#define GENERATE_SM90_W4A8_CO_CONFIG(M, N, K, A, B, C) \
struct JOIN_STRUCT_NAME_CO(M, N, K, A, B, C) { \
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; \
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; \
using TileShape = cute::Shape<cute::Int<M>, cute::Int<N>, cute::Int<K>>; \
using ClusterShape = cute::Shape<cute::Int<A>, cute::Int<B>, cute::Int<C>>; \
\
using Cutlass3xW4A8Gemm = cutlass_3x_w4a8_group_gemm<TileShape, ClusterShape, KernelSchedule, EpilogueSchedule>; \
};
GENERATE_SM90_W4A8_PP_CONFIG(64, 16, 512, 1, 1, 1)
GENERATE_SM90_W4A8_PP_CONFIG(64, 32, 512, 2, 1, 1)
GENERATE_SM90_W4A8_CO_CONFIG(128, 16, 512, 1, 1, 1)
GENERATE_SM90_W4A8_CO_CONFIG(128, 16, 512, 2, 1, 1)
GENERATE_SM90_W4A8_CO_CONFIG(128, 32, 512, 1, 1, 1)
GENERATE_SM90_W4A8_CO_CONFIG(128, 32, 512, 2, 1, 1)
GENERATE_SM90_W4A8_CO_CONFIG(128, 64, 512, 1, 1, 1)
void dispatch_w4a8_moe_mm_sm90(
torch::Tensor& d_tensors,
torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes,
torch::Tensor const& a_strides,
torch::Tensor const& b_strides,
torch::Tensor const& d_strides,
torch::Tensor const& s_strides,
int64_t chunk_size,
int64_t topk) {
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative;
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative;
uint32_t const m = a_tensors.size(0) / topk;
uint32_t const n = d_tensors.size(1);
uint32_t const k = a_tensors.size(1);
if (n == 4096 && k == 7168) {
// group gemm 1
if (m <= 4) {
using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME(64, 32, 512, 2, 1, 1)::Cutlass3xW4A8Gemm;
cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
d_tensors,
a_tensors,
b_tensors,
a_scales,
b_scales,
expert_offsets,
problem_sizes,
a_strides,
b_strides,
d_strides,
s_strides,
chunk_size);
} else if (m <= 16) {
using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 16, 512, 2, 1, 1)::Cutlass3xW4A8Gemm;
cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
d_tensors,
a_tensors,
b_tensors,
a_scales,
b_scales,
expert_offsets,
problem_sizes,
a_strides,
b_strides,
d_strides,
s_strides,
chunk_size);
} else if (m <= 256) {
using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 16, 512, 1, 1, 1)::Cutlass3xW4A8Gemm;
cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
d_tensors,
a_tensors,
b_tensors,
a_scales,
b_scales,
expert_offsets,
problem_sizes,
a_strides,
b_strides,
d_strides,
s_strides,
chunk_size);
} else if (m <= 1024) {
using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 32, 512, 2, 1, 1)::Cutlass3xW4A8Gemm;
cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
d_tensors,
a_tensors,
b_tensors,
a_scales,
b_scales,
expert_offsets,
problem_sizes,
a_strides,
b_strides,
d_strides,
s_strides,
chunk_size);
} else {
using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 64, 512, 1, 1, 1)::Cutlass3xW4A8Gemm;
cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
d_tensors,
a_tensors,
b_tensors,
a_scales,
b_scales,
expert_offsets,
problem_sizes,
a_strides,
b_strides,
d_strides,
s_strides,
chunk_size);
}
} else if (n == 7168 && k == 2048) {
// group gemm 2
if (m <= 8) {
using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME(64, 16, 512, 1, 1, 1)::Cutlass3xW4A8Gemm;
cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
d_tensors,
a_tensors,
b_tensors,
a_scales,
b_scales,
expert_offsets,
problem_sizes,
a_strides,
b_strides,
d_strides,
s_strides,
chunk_size);
} else if (m <= 512) {
using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 32, 512, 1, 1, 1)::Cutlass3xW4A8Gemm;
cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
d_tensors,
a_tensors,
b_tensors,
a_scales,
b_scales,
expert_offsets,
problem_sizes,
a_strides,
b_strides,
d_strides,
s_strides,
chunk_size);
} else {
using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 64, 512, 1, 1, 1)::Cutlass3xW4A8Gemm;
cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
d_tensors,
a_tensors,
b_tensors,
a_scales,
b_scales,
expert_offsets,
problem_sizes,
a_strides,
b_strides,
d_strides,
s_strides,
chunk_size);
}
} else {
using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 32, 512, 1, 1, 1)::Cutlass3xW4A8Gemm;
cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
d_tensors,
a_tensors,
b_tensors,
a_scales,
b_scales,
expert_offsets,
problem_sizes,
a_strides,
b_strides,
d_strides,
s_strides,
chunk_size);
}
}
} // namespace
void cutlass_w4a8_moe_mm_sm90(
torch::Tensor& d_tensors,
torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes,
torch::Tensor const& a_strides,
torch::Tensor const& b_strides,
torch::Tensor const& d_strides,
torch::Tensor const& s_strides,
int64_t chunk_size,
int64_t topk) {
dispatch_w4a8_moe_mm_sm90(
d_tensors,
a_tensors,
b_tensors,
a_scales,
b_scales,
expert_offsets,
problem_sizes,
a_strides,
b_strides,
d_strides,
s_strides,
chunk_size,
topk);
}

View File

@@ -0,0 +1,276 @@
#pragma once
/**
* @file w4a8_grouped_mm_c3x.cuh
* @brief Implementation of grouped GEMM operation with int4 and fp8 mixed
* precision
*
* This file implements a grouped GEMM operation that multiplies FP8 matrices
* (A) with quantized INT4 matrices (B), applying per-block scaling factors.
* The implementation is optimized for NVIDIA Hopper GPUs, leveraging Tensor
* Cores for mixed precision arithmetic.
*
* Key features:
* - Supports grouped GEMM operations with multiple experts
* - Uses FP8 (e4m3) for matrix A
* - Uses INT4 quantization for matrix B with per-block scaling
* - Implements preprocessing for INT4 encoding and scale packing
* - Optimized for Hopper architecture with Tensor Core operations
*/
#include <ATen/cuda/CUDAContext.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include <torch/all.h>
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/group_array_problem_shape.hpp"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass_extensions/gemm/collective/collective_builder_mixed_input.hpp"
#include "w4a8_get_group_starts.cuh"
using namespace cute;
namespace {
// Type definitions
using MmaType = cutlass::float_e4m3_t; // FP8 e4m3 type
using QuantType = cutlass::int4b_t; // 4-bit integer type
using ElementAccumulator = float; // Accumulator type
using ElementScale = cutlass::bfloat16_t; // Scale type
using ElementScalePacked = cutlass::Array<ElementScale, 4>;
using ElementC = cutlass::half_t; // Default output type (FP16)
using ElementD = ElementC; // Default output type (FP16)
using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int, int, int>>;
// Architecture-specific configurations
using ArchTag = cutlass::arch::Sm90;
using OperatorClass = cutlass::arch::OpClassTensorOp;
// constexpr int TileShapeK = 512;
// using TileShape = Shape<_128, _32, cute::Int<TileShapeK>>;
// using ClusterShape = Shape<_1, _1, _1>;
// Layout configurations
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = LayoutC;
// Transposed layouts
using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose<LayoutA>::type;
using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose<LayoutB>::type;
using LayoutC_Transpose = typename cutlass::layout::LayoutTranspose<LayoutC>::type;
using LayoutD_Transpose = typename cutlass::layout::LayoutTranspose<LayoutD>::type;
// Alignments
static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<MmaType>::value;
static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<QuantType>::value;
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
template <typename TileShape, typename ClusterShape, typename KernelSchedule, typename EpilogueSchedule>
struct cutlass_3x_w4a8_group_gemm {
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OperatorClass,
TileShape,
ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator,
ElementAccumulator,
ElementC,
LayoutC_Transpose*,
AlignmentC,
ElementD,
LayoutD_Transpose*,
AlignmentD,
EpilogueSchedule>::CollectiveOp;
using CollectiveMainloopScaleOnly = typename cutlass::gemm::collective::CollectiveBuilderMixedInput<
ArchTag,
OperatorClass,
cute::tuple<QuantType, ElementScalePacked>,
LayoutB_Transpose*,
AlignmentB,
MmaType,
LayoutA_Transpose*,
AlignmentA,
ElementAccumulator,
TileShape,
ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelSchedule>::CollectiveOp;
// Define the final kernel and GEMM operation types
using GemmKernelScaleOnly =
cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloopScaleOnly, CollectiveEpilogue>;
using GemmScaleOnly = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelScaleOnly>;
using StrideA = cute::remove_pointer_t<cutlass::detail::TagToStrideA_t<LayoutA*>>;
using StrideB = cute::remove_pointer_t<cutlass::detail::TagToStrideB_t<LayoutB*>>;
using StrideC = typename GemmKernelScaleOnly::InternalStrideC;
using StrideD = typename GemmKernelScaleOnly::InternalStrideD;
using StrideS = typename CollectiveMainloopScaleOnly::StrideScale;
};
/**
* @brief Main function to run int4 * fp8 grouped GEMM from PyTorch
*
* This function performs multiple GEMM operations in parallel where each
* operation multiplies an FP8 matrix (A) with a quantized INT4 matrix (B),
* applying per-channel scaling factors. It's designed for efficient execution
* on NVIDIA Hopper GPUs, leveraging Tensor Cores for optimal performance with
* mixed precision arithmetic.
*
* The function includes preprocessing steps for both INT4 tensors and scale
* factors to ensure optimal performance and correct operation.
*
* @param d_tensors Output tensor D with shape [total_m, total_n]
* @param a_tensors Tensor containing all A matrices (fp8_e4m3) with shape
* [total_m, K]
* @param b_tensors Tensor containing all B matrices (int4 packed as int8) with
* shape [E, N, K/2]
* @param a_scales Tensor containing A matrix scale factors
* @param b_scales Tensor containing B matrix scale factors with shape [E,
* K//512, N*4]
* @param expert_offsets Tensor containing expert offsets for determining group
* boundaries (int32)
* @param problem_sizes Tensor containing problem sizes with shape [num_experts,
* 3] (M, N, K for each group) (int32)
* @param a_strides Stride information for A tensors
* @param b_strides Stride information for B tensors
* @param d_strides Stride information for D tensors
* @param s_strides Stride information for scale tensors
* @param chunk_size Size of each chunk for scales (K / number of scale chunks)
*/
// template <typename TileShape, typename ClusterShape, typename KernelSchedule, typename EpilogueSchedule>
template <typename Gemm>
void cutlass_w4a8_group_gemm_caller(
torch::Tensor& d_tensors,
torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes,
torch::Tensor const& a_strides,
torch::Tensor const& b_strides,
torch::Tensor const& d_strides,
torch::Tensor const& s_strides,
int64_t chunk_size) {
// using Gemm = cutlass_3x_w4a8_group_gemm<TileShape, ClusterShape, KernelSchedule, EpilogueSchedule>;
using Args = typename Gemm::GemmScaleOnly::Arguments;
int num_experts = static_cast<int>(expert_offsets.size(0));
bool per_act_token = a_scales.numel() != 1;
bool per_out_ch = b_scales.numel() != num_experts;
// Check inputs
TORCH_CHECK(a_tensors.dim() == 2, "A tensor must be 2D");
TORCH_CHECK(b_tensors.dim() == 3, "B tensor must be 3D [E, N, K/2]");
TORCH_CHECK(b_scales.dim() == 3, "Scale tensor must be 3D [E, K//512, N*4]");
TORCH_CHECK(a_scales.dim() == 1, "A Scale tensor must be 1D [1]");
TORCH_CHECK(expert_offsets.dim() == 1, "expert_offsets must be a 1D tensor");
TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor");
// Check tensor shapes
TORCH_CHECK(problem_sizes.size(0) == num_experts, "problem_sizes must have num_experts rows");
TORCH_CHECK(problem_sizes.size(1) == 3, "problem_sizes must have 3 columns (N, M, K)");
TORCH_CHECK(b_tensors.size(0) == num_experts, "B tensor first dimension must match number of groups");
TORCH_CHECK(b_scales.size(0) == num_experts, "Scale tensor first dimension must match number of groups");
TORCH_CHECK(b_tensors.size(2) * 2 == a_tensors.size(1), "B tensor K/2 dimension must match A tensor K dimension");
TORCH_CHECK(b_scales.size(1) == a_tensors.size(1) / 512, "Scale tensor second dimension must be K//512");
TORCH_CHECK(b_scales.size(2) == 4 * b_tensors.size(1), "Scale tensor last dimension must be 4*N");
// Check tensor types
TORCH_CHECK(a_tensors.scalar_type() == torch::kFloat8_e4m3fn, "A tensor must be fp8 (float_e4m3_t) type");
TORCH_CHECK(b_tensors.scalar_type() == torch::kInt8, "B tensor must contain packed int4 values (stored as int8)");
TORCH_CHECK(expert_offsets.scalar_type() == torch::kInt32, "Expert offsets must be int32 type");
TORCH_CHECK(problem_sizes.scalar_type() == torch::kInt32, "Problem sizes must be int32 type");
auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
auto options_int = torch::TensorOptions().dtype(torch::kInt64).device(a_tensors.device());
torch::Tensor a_ptrs = torch::empty(num_experts, options_int);
torch::Tensor b_ptrs = torch::empty(num_experts, options_int);
torch::Tensor out_ptrs = torch::empty(num_experts, options_int);
torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int);
torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int);
cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = a_tensors.device().index();
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
Args arguments;
decltype(arguments.epilogue.thread) fusion_args;
fusion_args.alpha = 1.0f;
fusion_args.beta = 0;
fusion_args.alpha_ptr = a_scales.data_ptr<float>();
;
fusion_args.beta_ptr = nullptr;
fusion_args.alpha_ptr_array = nullptr;
fusion_args.beta_ptr_array = nullptr;
fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 0};
fusion_args.dBeta = {cute::_0{}, cute::_0{}, 0};
ProblemShape::UnderlyingProblemShape* problem_sizes_as_shapes =
static_cast<ProblemShape::UnderlyingProblemShape*>(problem_sizes.data_ptr());
run_int4_fp8_get_group_gemm_starts(
expert_offsets,
a_ptrs,
b_ptrs,
out_ptrs,
a_scales_ptrs,
b_scales_ptrs,
a_tensors,
b_tensors,
d_tensors,
a_scales,
b_scales);
arguments = Args{
cutlass::gemm::GemmUniversalMode::kGrouped,
{num_experts, problem_sizes_as_shapes, nullptr},
{static_cast<const QuantType**>(b_ptrs.data_ptr()),
static_cast<typename Gemm::StrideB*>(b_strides.data_ptr()),
static_cast<const MmaType**>(a_ptrs.data_ptr()),
static_cast<typename Gemm::StrideA*>(a_strides.data_ptr()),
static_cast<const ElementScalePacked**>(b_scales_ptrs.data_ptr()),
static_cast<typename Gemm::StrideS*>(s_strides.data_ptr()),
static_cast<int>(chunk_size)},
{fusion_args,
nullptr,
nullptr,
static_cast<ElementD**>(out_ptrs.data_ptr()),
static_cast<typename Gemm::StrideD*>(d_strides.data_ptr())},
hw_info};
// Instantiate and run GEMM
typename Gemm::GemmScaleOnly gemm;
size_t workspace_size = Gemm::GemmScaleOnly::get_workspace_size(arguments);
auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a_tensors.device());
auto workspace = torch::empty(workspace_size, workspace_options);
cutlass::Status status = gemm.can_implement(arguments);
if (status != cutlass::Status::kSuccess) {
TORCH_CHECK(false, "GEMM implementation not supported");
}
status = gemm.initialize(arguments, workspace.data_ptr(), stream);
if (status != cutlass::Status::kSuccess) {
TORCH_CHECK(false, "GEMM initialization failed");
}
status = gemm.run(stream);
if (status != cutlass::Status::kSuccess) {
TORCH_CHECK(false, "GEMM execution failed");
}
}
} // namespace

View File

@@ -0,0 +1,79 @@
#include <c10/cuda/CUDAGuard.h>
#include <cudaTypedefs.h>
#include <torch/all.h>
#include <iostream>
constexpr uint64_t THREADS_PER_EXPERT = 512;
__global__ void compute_problem_sizes_w4a8(
const int32_t* __restrict__ topk_ids,
int32_t* problem_sizes1,
int32_t* problem_sizes2,
int32_t* atomic_buffer,
const int topk_length,
const int n,
const int k) {
int expert_id = blockIdx.x;
int occurrences = 0;
for (int i = threadIdx.x; i < topk_length; i += THREADS_PER_EXPERT) {
occurrences += (topk_ids[i] == expert_id);
}
atomicAdd(&atomic_buffer[expert_id], occurrences);
__syncthreads();
if (threadIdx.x == 0) {
int final_occurrences = atomic_buffer[expert_id];
problem_sizes1[expert_id * 3] = 2 * n;
problem_sizes1[expert_id * 3 + 1] = final_occurrences;
problem_sizes1[expert_id * 3 + 2] = k;
problem_sizes2[expert_id * 3] = k;
problem_sizes2[expert_id * 3 + 1] = final_occurrences;
problem_sizes2[expert_id * 3 + 2] = n;
}
}
__global__ void compute_expert_offsets_w4a8(
const int32_t* __restrict__ problem_sizes1,
int32_t* expert_offsets,
int32_t* atomic_buffer,
const int num_experts) {
int32_t tot_offset = 0;
expert_offsets[0] = 0;
for (int i = 0; i < num_experts; ++i) {
atomic_buffer[i] = tot_offset;
tot_offset += problem_sizes1[i * 3 + 1];
expert_offsets[i + 1] = tot_offset;
}
}
void get_cutlass_w4a8_moe_mm_data_caller(
const torch::Tensor& topk_ids,
torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2,
torch::Tensor& input_permutation,
torch::Tensor& output_permutation,
const int64_t num_experts,
const int64_t n,
const int64_t k) {
auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index());
auto options_int32 = torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device());
torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32);
int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
compute_problem_sizes_w4a8<<<num_experts, num_threads, 0, stream>>>(
static_cast<const int32_t*>(topk_ids.data_ptr()),
static_cast<int32_t*>(problem_sizes1.data_ptr()),
static_cast<int32_t*>(problem_sizes2.data_ptr()),
static_cast<int32_t*>(atomic_buffer.data_ptr()),
topk_ids.numel(),
n,
k);
compute_expert_offsets_w4a8<<<1, 1, 0, stream>>>(
static_cast<const int32_t*>(problem_sizes1.data_ptr()),
static_cast<int32_t*>(expert_offsets.data_ptr()),
static_cast<int32_t*>(atomic_buffer.data_ptr()),
num_experts);
}