[1/n]: add cutlass W4A8 moe kernel for hopper architecture (#7772)
Signed-off-by: yangsijia.614 <yangsijia.614@bytedance.com> Co-authored-by: yicwang <yichen.wang@bytedance.com>
This commit is contained in:
@@ -249,6 +249,9 @@ set(SOURCES
|
|||||||
"csrc/speculative/speculative_sampling.cu"
|
"csrc/speculative/speculative_sampling.cu"
|
||||||
"csrc/grammar/apply_token_bitmask_inplace_cuda.cu"
|
"csrc/grammar/apply_token_bitmask_inplace_cuda.cu"
|
||||||
"csrc/kvcacheio/transfer.cu"
|
"csrc/kvcacheio/transfer.cu"
|
||||||
|
"csrc/moe/cutlass_moe/w4a8/scaled_mm_entry.cu"
|
||||||
|
"csrc/moe/cutlass_moe/w4a8/w4a8_moe_data.cu"
|
||||||
|
"csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cu"
|
||||||
"csrc/common_extension.cc"
|
"csrc/common_extension.cc"
|
||||||
"csrc/moe/marlin_moe_wna16/ops.cu"
|
"csrc/moe/marlin_moe_wna16/ops.cu"
|
||||||
"csrc/moe/marlin_moe_wna16/gptq_marlin_repack.cu"
|
"csrc/moe/marlin_moe_wna16/gptq_marlin_repack.cu"
|
||||||
|
|||||||
@@ -277,6 +277,25 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
|||||||
"int num_layers) -> ()");
|
"int num_layers) -> ()");
|
||||||
m.impl("transfer_kv_all_layer_mla_direct", torch::kCUDA, &transfer_kv_all_layer_mla_direct);
|
m.impl("transfer_kv_all_layer_mla_direct", torch::kCUDA, &transfer_kv_all_layer_mla_direct);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* From csrc/moe/cutlass_moe/w4a8
|
||||||
|
*/
|
||||||
|
m.def(
|
||||||
|
"get_cutlass_w4a8_moe_mm_data(Tensor topk_ids, Tensor! expert_offsets, "
|
||||||
|
" Tensor! problem_sizes1, Tensor! problem_sizes2, "
|
||||||
|
" Tensor! input_permutation, "
|
||||||
|
" Tensor! output_permutation, int num_experts, "
|
||||||
|
" int n, int k) -> ()");
|
||||||
|
m.impl("get_cutlass_w4a8_moe_mm_data", torch::kCUDA, &get_cutlass_w4a8_moe_mm_data);
|
||||||
|
|
||||||
|
m.def(
|
||||||
|
"cutlass_w4a8_moe_mm(Tensor! d, Tensor a, Tensor b, "
|
||||||
|
" Tensor a_scales, Tensor b_scales, Tensor expert_offsets, "
|
||||||
|
" Tensor problem_sizes, Tensor a_strides, "
|
||||||
|
" Tensor b_strides, Tensor d_strides, Tensor s_strides,"
|
||||||
|
" int chunk_size, int topk) -> ()");
|
||||||
|
m.impl("cutlass_w4a8_moe_mm", torch::kCUDA, &cutlass_w4a8_moe_mm);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* From FlashInfer
|
* From FlashInfer
|
||||||
*/
|
*/
|
||||||
|
|||||||
@@ -0,0 +1,482 @@
|
|||||||
|
/*
|
||||||
|
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "cute/arch/copy_sm90.hpp"
|
||||||
|
#include "cute/numeric/arithmetic_tuple.hpp"
|
||||||
|
#include "cute/util/type_traits.hpp"
|
||||||
|
#include "cutlass/cutlass.h"
|
||||||
|
#include "cutlass/numeric_conversion.h"
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
namespace cutlass::gemm::collective::detail {
|
||||||
|
|
||||||
|
template <class Collective>
|
||||||
|
struct MixedGroupedGemmInputUtils {
|
||||||
|
private:
|
||||||
|
using KernelSchedule = typename Collective::KernelSchedule;
|
||||||
|
using ConversionMode = typename Collective::ConversionMode;
|
||||||
|
using SmemLayoutA = typename Collective::SmemLayoutA;
|
||||||
|
using SmemLayoutB = typename Collective::SmemLayoutB;
|
||||||
|
using SmemLayoutScale = typename Collective::SmemLayoutScale;
|
||||||
|
using SwappedElementA = typename Collective::SwappedElementA;
|
||||||
|
using SwappedElementB = typename Collective::SwappedElementB;
|
||||||
|
using RealSwappedElementA = typename Collective::RealSwappedElementA;
|
||||||
|
using RealSwappedElementB = typename Collective::RealSwappedElementB;
|
||||||
|
using ElementScale = typename Collective::ElementScale;
|
||||||
|
using ElementZero = typename Collective::ElementZero;
|
||||||
|
using SmemCopyAtomScale = typename Collective::SmemCopyAtomScale;
|
||||||
|
static constexpr auto KernelConversionMode = Collective::KernelConversionMode;
|
||||||
|
static constexpr auto ModeHasScales = Collective::ModeHasScales;
|
||||||
|
static constexpr auto UseScaleLookupTable = Collective::UseScaleLookupTable;
|
||||||
|
|
||||||
|
public:
|
||||||
|
static constexpr auto elements_per_smem_scale() {
|
||||||
|
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
|
||||||
|
return 0;
|
||||||
|
} else if constexpr (ModeHasScales) {
|
||||||
|
return cute::cosize_v<SmemLayoutScale>;
|
||||||
|
} else {
|
||||||
|
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Type not handled in scale smem allocation.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static constexpr auto elements_per_smem_zero() {
|
||||||
|
if constexpr (
|
||||||
|
KernelConversionMode == ConversionMode::DirectConvert ||
|
||||||
|
KernelConversionMode == ConversionMode::ConvertAndScale) {
|
||||||
|
return 0;
|
||||||
|
} else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
|
||||||
|
return cute::cosize_v<SmemLayoutScale>;
|
||||||
|
} else {
|
||||||
|
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Type not handled in scale smem allocation.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// These methods use some the public members of the class. For that reason, we define them after the public section.
|
||||||
|
static constexpr uint32_t compute_tma_transaction_bytes_mk() {
|
||||||
|
return cutlass::bits_to_bytes(
|
||||||
|
size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast<uint32_t>(cute::sizeof_bits_v<SwappedElementA>));
|
||||||
|
}
|
||||||
|
|
||||||
|
static constexpr uint32_t compute_tma_transaction_bytes_nk() {
|
||||||
|
return cutlass::bits_to_bytes(
|
||||||
|
size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast<uint32_t>(cute::sizeof_bits_v<SwappedElementB>));
|
||||||
|
}
|
||||||
|
|
||||||
|
static constexpr uint32_t compute_tma_transaction_bytes_extra() {
|
||||||
|
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
|
||||||
|
return 0;
|
||||||
|
} else if constexpr (ModeHasScales) {
|
||||||
|
constexpr uint32_t scale_tx_bytes = cutlass::bits_to_bytes(
|
||||||
|
size<0>(SmemLayoutScale{}) * size<1>(SmemLayoutScale{}) *
|
||||||
|
static_cast<uint32_t>(cute::sizeof_bits_v<ElementScale>));
|
||||||
|
static_assert(scale_tx_bytes % 128 == 0, "Each scale stage must be 128B aligned."); // required by TMA
|
||||||
|
if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
|
||||||
|
return scale_tx_bytes;
|
||||||
|
} else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
|
||||||
|
// Scale and zero share smem layout
|
||||||
|
constexpr uint32_t zero_tx_bytes = cutlass::bits_to_bytes(
|
||||||
|
size<0>(SmemLayoutScale{}) * size<1>(SmemLayoutScale{}) *
|
||||||
|
static_cast<uint32_t>(cute::sizeof_bits_v<ElementZero>));
|
||||||
|
static_assert(zero_tx_bytes % 128 == 0, "Each zero stage must be 128B aligned."); // required by TMA
|
||||||
|
return scale_tx_bytes + zero_tx_bytes;
|
||||||
|
} else {
|
||||||
|
static_assert(
|
||||||
|
cutlass::detail::dependent_false<KernelSchedule>, "Type not handled in tma transaction bytes computation.");
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
static_assert(
|
||||||
|
cutlass::detail::dependent_false<KernelSchedule>, "Type not handled in tma transaction bytes computation.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Utilities to copy A and extra inputs from smem to RF
|
||||||
|
template <class SmemTiledCopyA, class TensorASmemView, class TensorACopyView, class... Ts, class... Us>
|
||||||
|
CUTLASS_DEVICE static void copy_tensors_MK(
|
||||||
|
SmemTiledCopyA const& smem_tiled_copy_A,
|
||||||
|
TensorASmemView const& tCsA,
|
||||||
|
TensorACopyView& tCrA_copy_view,
|
||||||
|
cute::tuple<Ts...> const& partitioned_mma_extra_info,
|
||||||
|
cute::tuple<Us...> const& tiled_copy_and_views,
|
||||||
|
int k_block,
|
||||||
|
int read_stage) {
|
||||||
|
copy(smem_tiled_copy_A, tCsA(_, _, k_block, read_stage), tCrA_copy_view(_, _, k_block));
|
||||||
|
|
||||||
|
if (k_block == 0) {
|
||||||
|
// We are starting a new k-tile so copy the scale
|
||||||
|
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
|
||||||
|
// nothing to do
|
||||||
|
} else if constexpr (ModeHasScales) {
|
||||||
|
auto smem_tiled_copy_S = cute::get<0>(tiled_copy_and_views);
|
||||||
|
auto tCrS_copy_view = cute::get<1>(tiled_copy_and_views);
|
||||||
|
auto tCsS = cute::get<0>(partitioned_mma_extra_info);
|
||||||
|
copy(smem_tiled_copy_S, tCsS(_, _, k_block, read_stage), tCrS_copy_view(_, _, k_block));
|
||||||
|
if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
|
||||||
|
// Nothing extra to do
|
||||||
|
} else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
|
||||||
|
auto tCsZ = cute::get<2>(partitioned_mma_extra_info);
|
||||||
|
auto tCrZ_copy_view = cute::get<2>(tiled_copy_and_views);
|
||||||
|
copy(smem_tiled_copy_S, tCsZ(_, _, k_block, read_stage), tCrZ_copy_view(_, _, k_block));
|
||||||
|
} else {
|
||||||
|
static_assert(
|
||||||
|
cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in A -> RF path.");
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in A -> RF path.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// The core converter uses a lookup table to converts i4 -> 8 bit value.
|
||||||
|
template <
|
||||||
|
class EngineIn,
|
||||||
|
class LayoutIn,
|
||||||
|
class EngineOut,
|
||||||
|
class LayoutOut,
|
||||||
|
class EngineScale,
|
||||||
|
class LayoutScale>
|
||||||
|
CUTLASS_DEVICE static void lookup_table_convert( // Accept mutable temporaries
|
||||||
|
Tensor<EngineIn, LayoutIn> const& src,
|
||||||
|
Tensor<EngineOut, LayoutOut>&& dst,
|
||||||
|
Tensor<EngineScale, LayoutScale> const& scales_neg,
|
||||||
|
Tensor<EngineScale, LayoutScale> const& scales_pos) {
|
||||||
|
lookup_table_convert(src, dst, scales_neg, scales_pos);
|
||||||
|
}
|
||||||
|
template <class EngineIn, class LayoutIn, class EngineOut, class LayoutOut, class EngineScale, class LayoutScale>
|
||||||
|
CUTLASS_DEVICE static void lookup_table_convert(
|
||||||
|
Tensor<EngineIn, LayoutIn> const& src,
|
||||||
|
Tensor<EngineOut, LayoutOut>& dst,
|
||||||
|
Tensor<EngineScale, LayoutScale> const& scales_neg,
|
||||||
|
Tensor<EngineScale, LayoutScale> const& scales_pos) {
|
||||||
|
constexpr int N = cute::cosize(LayoutIn{});
|
||||||
|
static_assert(N == 4 || N == 8);
|
||||||
|
static_assert(cosize(LayoutScale{}) <= N / 4, "at least 4 consecutive weights must share the same scale.");
|
||||||
|
using SrcArray = cutlass::Array<cutlass::int4b_t, 8>;
|
||||||
|
using DstArray = cutlass::Array<RealSwappedElementB, 8>;
|
||||||
|
using RegArray = cutlass::AlignedArray<uint32_t, N / 4, sizeof(DstArray)>;
|
||||||
|
|
||||||
|
// View the input as reg
|
||||||
|
auto&& src_reg = cute::recast<uint32_t>(src)(0);
|
||||||
|
auto&& r = cute::recast<RegArray>(dst)(0);
|
||||||
|
|
||||||
|
// Determines if to get from the signed or unsigned candidates
|
||||||
|
static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
|
||||||
|
uint32_t sign; // ((reg & 0x88888888) | 0x64206420) >> 1
|
||||||
|
asm volatile(
|
||||||
|
"{\n"
|
||||||
|
" lop3.b32 %0, %1, %2, %3, %4;\n"
|
||||||
|
"}\n"
|
||||||
|
: "=r"(sign)
|
||||||
|
: "r"(src_reg), "n"(0x88888888), "n"(0x64206420), "n"(immLut));
|
||||||
|
sign = sign >> 1;
|
||||||
|
|
||||||
|
// Ignore sign bit when indexing into LUT
|
||||||
|
uint32_t lut_idx = src_reg & 0x77777777;
|
||||||
|
Tensor scales_neg_ = cute::filter(scales_neg);
|
||||||
|
Tensor scales_pos_ = cute::filter(scales_pos);
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int i = 0; i < N / 4; ++i, lut_idx >>= 16, sign >>= 16) {
|
||||||
|
auto&& scale_neg_ = reinterpret_cast<cutlass::Array<uint32_t, 2> const&>(scales_neg_(i));
|
||||||
|
auto&& scale_pos_ = reinterpret_cast<cutlass::Array<uint32_t, 2> const&>(scales_pos_(i));
|
||||||
|
asm volatile(
|
||||||
|
"{\n"
|
||||||
|
" .reg .b32 pos, neg ;\n"
|
||||||
|
" prmt .b32 neg, %3, %4, %1 ;\n"
|
||||||
|
" prmt .b32 pos, %5, %6, %1 ;\n"
|
||||||
|
" prmt .b32 %0, pos, neg, %2 ;\n"
|
||||||
|
"}\n"
|
||||||
|
: "=r"(r[i])
|
||||||
|
: "r"(lut_idx), "r"(sign), "r"(scale_neg_[0]), "r"(scale_neg_[1]), "r"(scale_pos_[0]), "r"(scale_pos_[1]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Utilities to dequantize A.
|
||||||
|
template <class Layout>
|
||||||
|
CUTLASS_DEVICE static void static_check_scale(Layout const& tensor) {
|
||||||
|
static_assert(
|
||||||
|
shape<0>(Layout{}) >= 4 && stride<0>(Layout{}) == 0,
|
||||||
|
"At least 4 adjacent weights in a thread must share the same scale.");
|
||||||
|
}
|
||||||
|
template <class Engine, class Layout>
|
||||||
|
CUTLASS_DEVICE static void static_check_scale(Tensor<Engine, Layout> const& tensor) {
|
||||||
|
static_check_scale(flatten(Layout{}));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class EngineIn, class EngineOut, class LayoutIn, class LayoutOut, class... Ts>
|
||||||
|
CUTLASS_DEVICE static void dequantize_A_kblock(
|
||||||
|
Tensor<EngineIn, LayoutIn> const& tCrA_load,
|
||||||
|
Tensor<EngineOut, LayoutOut>& tCrA_mma,
|
||||||
|
cute::tuple<Ts...>& partitioned_extra_info,
|
||||||
|
int const k_block) {
|
||||||
|
static_assert(is_rmem<EngineIn>::value, "Input tensor for A conversion must come from registers");
|
||||||
|
static_assert(is_rmem<EngineOut>::value, "Output tensor for A conversion must come from registers");
|
||||||
|
static_assert(cosize_v<LayoutIn> == cosize_v<LayoutOut>);
|
||||||
|
static_assert(size_v<LayoutIn> == cosize_v<LayoutIn>);
|
||||||
|
static_assert(size_v<LayoutOut> == cosize_v<LayoutOut>);
|
||||||
|
using SrcType = typename EngineIn::value_type;
|
||||||
|
using DstType = typename EngineOut::value_type;
|
||||||
|
|
||||||
|
Tensor src = tCrA_load(_, _, k_block);
|
||||||
|
Tensor dst = tCrA_mma(_, _, k_block);
|
||||||
|
|
||||||
|
CUTE_STATIC_ASSERT_V(
|
||||||
|
size(src(_, 0)) == cosize(src(_, 0).layout()), "The first mode of tensor src must be contiguous in memory");
|
||||||
|
// try to make the size of the first mode equal to 32bit
|
||||||
|
int constexpr NumValPerSrcReg = cute::min(decltype(size(src(_, 0)))::value, ceil_div(32, sizeof_bits_v<SrcType>));
|
||||||
|
Tensor src_vm = cute::group_modes<1, -1>(cute::zipped_divide(src, Int<NumValPerSrcReg>{}));
|
||||||
|
Tensor dst_vm = cute::group_modes<1, -1>(cute::zipped_divide(dst, Int<NumValPerSrcReg>{}));
|
||||||
|
|
||||||
|
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int i = 0; i < size<1>(dst_vm); ++i) {
|
||||||
|
LayoutAwareConvert(src_vm(_, i), dst_vm(_, i));
|
||||||
|
}
|
||||||
|
} else if constexpr (UseScaleLookupTable) {
|
||||||
|
constexpr int num_elements = decltype(size(src))::value;
|
||||||
|
static_assert(
|
||||||
|
is_same_v<RealSwappedElementA, cutlass::int4b_t>,
|
||||||
|
"Lookup table only supports int4 being the quant type now.");
|
||||||
|
static_assert(sizeof_bits_v<ElementScale> == 64, "Lookup table only supports 8 8bit scale values now.");
|
||||||
|
static_assert(
|
||||||
|
num_elements % 4 == 0 && num_elements >= 4, "Lookup table requires a vector size of 4x when converting.");
|
||||||
|
|
||||||
|
Tensor tCrS_neg = cute::get<1>(partitioned_extra_info);
|
||||||
|
auto&& tCrS_pos = cute::get<2>(partitioned_extra_info); // modification to its value is needed
|
||||||
|
Tensor scales_neg = tCrS_neg(_, _, k_block);
|
||||||
|
Tensor scales_pos = tCrS_pos(_, _, k_block);
|
||||||
|
CUTE_STATIC_ASSERT_V(cute::size(src) == cute::size(scales_neg));
|
||||||
|
|
||||||
|
static_check_scale(scales_neg);
|
||||||
|
static_check_scale(scales_pos);
|
||||||
|
Tensor scales_neg_vm = cute::group_modes<1, -1>(cute::zipped_divide(scales_neg, Int<NumValPerSrcReg>{}));
|
||||||
|
Tensor scales_pos_vm = cute::group_modes<1, -1>(cute::zipped_divide(scales_pos, Int<NumValPerSrcReg>{}));
|
||||||
|
|
||||||
|
if (k_block == 0) {
|
||||||
|
Tensor scales_neg_vm_ = filter(scales_neg_vm);
|
||||||
|
Tensor scales_pos_vm_ = filter(scales_pos_vm);
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int i = 0; i < size(scales_neg_vm_.layout()); ++i) {
|
||||||
|
auto&& scale_neg_ = reinterpret_cast<cutlass::Array<uint32_t, 2> const&>(scales_neg_vm_(i));
|
||||||
|
auto&& scale_pos_ = reinterpret_cast<cutlass::Array<uint32_t, 2>&>(scales_pos_vm_(i));
|
||||||
|
constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa;
|
||||||
|
asm volatile(
|
||||||
|
"{\n"
|
||||||
|
" lop3 .b32 %0, %2, %4, %5, %6;\n"
|
||||||
|
" xor .b32 %1, %3, %5; \n"
|
||||||
|
"}\n"
|
||||||
|
: "=r"(scale_pos_[0]), "=r"(scale_pos_[1])
|
||||||
|
: "r"(scale_neg_[0]), "r"(scale_neg_[1]), "n"(0xFFFFFF00), "n"(0x80808080), "n"(immLut));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int i = 0; i < size<1>(dst_vm); ++i) {
|
||||||
|
lookup_table_convert(src_vm(_, i), dst_vm(_, i), scales_neg_vm(_, i), scales_pos_vm(_, i));
|
||||||
|
}
|
||||||
|
} else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
|
||||||
|
Tensor scales = cute::get<1>(partitioned_extra_info)(_, _, k_block);
|
||||||
|
CUTE_STATIC_ASSERT_V(size(src) == size(scales));
|
||||||
|
Tensor scales_vm = cute::group_modes<1, -1>(cute::zipped_divide(scales, Int<NumValPerSrcReg>{}));
|
||||||
|
|
||||||
|
if constexpr (is_same_v<DstType, ElementScale>) {
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int i = 0; i < size<1>(dst_vm); ++i) {
|
||||||
|
LayoutAwareConvert(src_vm(_, i), dst_vm(_, i));
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int j = 0; j < size<0>(dst_vm); ++j) {
|
||||||
|
dst_vm(j, i) *= scales_vm(j, i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
auto stage = make_tensor_like<ElementScale>(src_vm(_, 0));
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int i = 0; i < size<1>(dst_vm); ++i) {
|
||||||
|
LayoutAwareConvert(src_vm(_, i), stage);
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int j = 0; j < size<0>(dst_vm); ++j) {
|
||||||
|
stage(j) *= scales_vm(j, i);
|
||||||
|
}
|
||||||
|
LayoutAwareConvert(stage, dst_vm(_, i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
|
||||||
|
static_assert(is_same_v<ElementScale, ElementZero>, "ElementScale and ElementZero must be the same.");
|
||||||
|
Tensor scales = cute::get<1>(partitioned_extra_info)(_, _, k_block);
|
||||||
|
Tensor zeros = cute::get<3>(partitioned_extra_info)(_, _, k_block);
|
||||||
|
CUTE_STATIC_ASSERT_V(size(src) == size(scales));
|
||||||
|
CUTE_STATIC_ASSERT_V(size(src) == size(zeros));
|
||||||
|
Tensor scales_vm = cute::group_modes<1, -1>(cute::zipped_divide(scales, Int<NumValPerSrcReg>{}));
|
||||||
|
Tensor zeros_vm = cute::group_modes<1, -1>(cute::zipped_divide(zeros, Int<NumValPerSrcReg>{}));
|
||||||
|
|
||||||
|
if constexpr (is_same_v<DstType, ElementScale>) {
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int i = 0; i < size<1>(dst_vm); ++i) {
|
||||||
|
LayoutAwareConvert(src_vm(_, i), dst_vm(_, i));
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int j = 0; j < size<0>(dst_vm); ++j) {
|
||||||
|
dst_vm(j, i) = dst_vm(j, i) * scales_vm(j, i) + zeros_vm(j, i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
auto stage = make_tensor_like<ElementScale>(src_vm(_, 0));
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int i = 0; i < size<1>(dst_vm); ++i) {
|
||||||
|
LayoutAwareConvert(src_vm(_, i), stage);
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int j = 0; j < size<0>(dst_vm); ++j) {
|
||||||
|
stage(j) = stage(j) * scales_vm(j, i) + zeros_vm(j, i);
|
||||||
|
}
|
||||||
|
LayoutAwareConvert(stage, dst_vm(_, i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "No A data is loaded.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class EngineIn, class EngineOut, class LayoutIn, class LayoutOut, class... Ts>
|
||||||
|
CUTLASS_DEVICE static void convert_A_kblock(
|
||||||
|
Tensor<EngineIn, LayoutIn> const& tCrA_load, Tensor<EngineOut, LayoutOut>& tCrA_mma, int const k_block) {
|
||||||
|
static_assert(is_rmem<EngineIn>::value, "Input tensor for A conversion must come from registers");
|
||||||
|
static_assert(is_rmem<EngineOut>::value, "Output tensor for A conversion must come from registers");
|
||||||
|
static_assert(cosize_v<LayoutIn> == cosize_v<LayoutOut>);
|
||||||
|
static_assert(size_v<LayoutIn> == cosize_v<LayoutIn>);
|
||||||
|
static_assert(size_v<LayoutOut> == cosize_v<LayoutOut>);
|
||||||
|
using SrcType = typename EngineIn::value_type;
|
||||||
|
|
||||||
|
Tensor src = tCrA_load(_, _, k_block);
|
||||||
|
Tensor dst = tCrA_mma(_, _, k_block);
|
||||||
|
|
||||||
|
CUTE_STATIC_ASSERT_V(
|
||||||
|
size(src(_, 0)) == cosize(src(_, 0).layout()), "The first mode of tensor src must be contiguous in memory");
|
||||||
|
// try to make the size of the first mode equal to 32bit
|
||||||
|
int constexpr NumValPerSrcReg = cute::min(decltype(size(src(_, 0)))::value, ceil_div(32, sizeof_bits_v<SrcType>));
|
||||||
|
Tensor src_vm = cute::group_modes<1, -1>(cute::zipped_divide(src, Int<NumValPerSrcReg>{}));
|
||||||
|
Tensor dst_vm = cute::group_modes<1, -1>(cute::zipped_divide(dst, Int<NumValPerSrcReg>{}));
|
||||||
|
|
||||||
|
// KernelConversionMode == ConversionMode::DirectConvert
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int i = 0; i < size<1>(dst_vm); ++i) {
|
||||||
|
LayoutAwareConvert(src_vm(_, i), dst_vm(_, i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Utilities for any additional inputs inside of the TMA load
|
||||||
|
template <class Params, class TensorStorage, class... Ts>
|
||||||
|
CUTLASS_DEVICE static auto partition_extra_tma_inputs(
|
||||||
|
Params const& mainloop_params,
|
||||||
|
cute::tuple<Ts...> const& load_inputs,
|
||||||
|
TensorStorage& shared_tensors,
|
||||||
|
uint2 const& cluster_local_block_id,
|
||||||
|
int const m_coord,
|
||||||
|
int const l_coord) {
|
||||||
|
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
|
||||||
|
return cute::make_tuple();
|
||||||
|
} else if constexpr (ModeHasScales) {
|
||||||
|
Tensor sS =
|
||||||
|
make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{}); // (BLK_M,BLK_K,PIPE)
|
||||||
|
Tensor gS_mkl = get<2>(load_inputs);
|
||||||
|
auto block_tma_s = mainloop_params.tma_load_scale.get_slice(cluster_local_block_id.y);
|
||||||
|
Tensor gS = gS_mkl(_, _, m_coord, _, l_coord); // (BLK_M,BLK_K,k)
|
||||||
|
|
||||||
|
Tensor tSgS = block_tma_s.partition_S(gS); // (TMA,TMA_M,TMA_K,k)
|
||||||
|
Tensor tSsS = block_tma_s.partition_D(sS); // (TMA,TMA_M,TMA_K,PIPE)
|
||||||
|
if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
|
||||||
|
return cute::make_tuple(tSgS, tSsS);
|
||||||
|
} else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
|
||||||
|
Tensor sZ =
|
||||||
|
make_tensor(make_smem_ptr(shared_tensors.smem_zero.begin()), SmemLayoutScale{}); // (BLK_M,BLK_K,PIPE)
|
||||||
|
Tensor gZ_mkl = get<3>(load_inputs);
|
||||||
|
auto block_tma_z = mainloop_params.tma_load_zero.get_slice(cluster_local_block_id.y);
|
||||||
|
Tensor gZ = gZ_mkl(_, _, m_coord, _, l_coord); // (BLK_M,BLK_K,k)
|
||||||
|
|
||||||
|
Tensor tZgZ = block_tma_z.partition_S(gZ); // (TMA,TMA_M,TMA_K,k)
|
||||||
|
Tensor tZsZ = block_tma_z.partition_D(sZ); // (TMA,TMA_M,TMA_K,PIPE)
|
||||||
|
return cute::make_tuple(tSgS, tSsS, tZgZ, tZsZ);
|
||||||
|
} else {
|
||||||
|
static_assert(
|
||||||
|
cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled for input partitioning.");
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
static_assert(
|
||||||
|
cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled for input partitioning.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Utilities for partitioning extra inputs for loading from smem in the mainloop.
|
||||||
|
template <class ThreadMma, class TensorStorage>
|
||||||
|
CUTLASS_DEVICE static auto
|
||||||
|
partition_extra_mma_info(ThreadMma const& mma_thread_slice, TensorStorage& shared_tensors) {
|
||||||
|
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
|
||||||
|
// nothing to do
|
||||||
|
return cute::make_tuple();
|
||||||
|
} else if constexpr (UseScaleLookupTable) {
|
||||||
|
Tensor sS =
|
||||||
|
make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{}); // (BLK_M,BLK_SCALE_K,PIPE)
|
||||||
|
Tensor tCsS = mma_thread_slice.partition_A(sS);
|
||||||
|
Tensor tCrS = make_tensor<ElementScale>(mma_thread_slice.partition_fragment_A(sS(_, _, Int<0>{})).layout());
|
||||||
|
|
||||||
|
return cute::make_tuple(tCsS, tCrS);
|
||||||
|
} else if constexpr (ModeHasScales) {
|
||||||
|
Tensor sS =
|
||||||
|
make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{}); // (BLK_M,BLK_SCALE_K,PIPE)
|
||||||
|
Tensor tCsS = mma_thread_slice.partition_A(sS);
|
||||||
|
Tensor tCrS = make_tensor<ElementScale>(mma_thread_slice.partition_fragment_A(sS(_, _, Int<0>{})).layout());
|
||||||
|
|
||||||
|
if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
|
||||||
|
return cute::make_tuple(tCsS, tCrS);
|
||||||
|
} else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
|
||||||
|
Tensor sZ = make_tensor(
|
||||||
|
make_smem_ptr(shared_tensors.smem_zero.begin()), SmemLayoutScale{}); // (BLK_M,BLK_SCALE_K,PIPE)
|
||||||
|
Tensor tCsZ = mma_thread_slice.partition_A(sZ);
|
||||||
|
Tensor tCrZ = make_tensor<ElementZero>(mma_thread_slice.partition_fragment_A(sZ(_, _, Int<0>{})).layout());
|
||||||
|
return cute::make_tuple(tCsS, tCrS, tCsZ, tCrZ);
|
||||||
|
} else {
|
||||||
|
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in A -> RF path.");
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in A -> RF path.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the tiled copy and copy views for the extra inputs.
|
||||||
|
template <class TiledMma, class... Ts>
|
||||||
|
CUTLASS_DEVICE static auto retile_extra_mma_info(
|
||||||
|
TiledMma const& tiled_mma, cute::tuple<Ts...>& partitioned_extra_info, int const warp_group_thread_idx) {
|
||||||
|
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
|
||||||
|
// nothing to do
|
||||||
|
return cute::make_tuple();
|
||||||
|
} else if constexpr (ModeHasScales) {
|
||||||
|
auto smem_tiled_copy_S = make_tiled_copy_A(SmemCopyAtomScale{}, tiled_mma);
|
||||||
|
auto smem_thr_copy_S = smem_tiled_copy_S.get_thread_slice(warp_group_thread_idx);
|
||||||
|
Tensor tCrS_copy_view = smem_thr_copy_S.retile_D(cute::get<1>(partitioned_extra_info)); // (CPY,CPY_M,CPY_K)
|
||||||
|
|
||||||
|
if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
|
||||||
|
return cute::make_tuple(smem_tiled_copy_S, tCrS_copy_view);
|
||||||
|
} else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
|
||||||
|
Tensor tCrZ_copy_view = smem_thr_copy_S.retile_D(cute::get<3>(partitioned_extra_info)); // (CPY,CPY_M,CPY_K)
|
||||||
|
return cute::make_tuple(smem_tiled_copy_S, tCrS_copy_view, tCrZ_copy_view);
|
||||||
|
} else {
|
||||||
|
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in A -> RF path.");
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in A -> RF path.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace cutlass::gemm::collective::detail
|
||||||
@@ -0,0 +1,278 @@
|
|||||||
|
/*
|
||||||
|
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "cute/arch/cluster_sm90.hpp"
|
||||||
|
#include "cute/tensor.hpp"
|
||||||
|
#include "cutlass/gemm/collective/builders/sm90_common.inl"
|
||||||
|
#include "cutlass/gemm/collective/collective_builder_decl.hpp"
|
||||||
|
#include "cutlass/gemm/collective/collective_mma_decl.hpp"
|
||||||
|
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||||
|
#include "cutlass/pipeline/sm90_pipeline.hpp"
|
||||||
|
|
||||||
|
// SM90 Collective Builders should be used only starting CUDA 12.0
|
||||||
|
#if (__CUDACC_VER_MAJOR__ >= 12)
|
||||||
|
#define CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
|
||||||
|
#endif
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
namespace cutlass::gemm::collective {
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
// GMMA_TMA_WS_RS
|
||||||
|
template <
|
||||||
|
class ElementA_,
|
||||||
|
class GmemLayoutATag_,
|
||||||
|
int AlignmentA,
|
||||||
|
class ElementB_,
|
||||||
|
class GmemLayoutBTag_,
|
||||||
|
int AlignmentB,
|
||||||
|
class ElementAccumulator,
|
||||||
|
class TileShape_MNK,
|
||||||
|
class ClusterShape_MNK,
|
||||||
|
class StageCountType,
|
||||||
|
class KernelScheduleType>
|
||||||
|
struct CollectiveBuilderMixedInput<
|
||||||
|
arch::Sm90,
|
||||||
|
arch::OpClassTensorOp,
|
||||||
|
ElementA_,
|
||||||
|
GmemLayoutATag_,
|
||||||
|
AlignmentA,
|
||||||
|
ElementB_,
|
||||||
|
GmemLayoutBTag_,
|
||||||
|
AlignmentB,
|
||||||
|
ElementAccumulator,
|
||||||
|
TileShape_MNK,
|
||||||
|
ClusterShape_MNK,
|
||||||
|
StageCountType,
|
||||||
|
KernelScheduleType,
|
||||||
|
cute::enable_if_t<
|
||||||
|
(cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecialized> ||
|
||||||
|
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedPingpong> ||
|
||||||
|
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperative> ||
|
||||||
|
cute::is_same_v<KernelScheduleType, KernelPtrArrayTmaWarpSpecializedCooperative> ||
|
||||||
|
cute::is_same_v<KernelScheduleType, KernelPtrArrayTmaWarpSpecializedPingpong>) &&
|
||||||
|
(detail::is_use_rmem_A<ElementA_, GmemLayoutATag_, ElementB_, GmemLayoutBTag_>() ||
|
||||||
|
// ConvertAndScale and ConvertAndScaleWithZero
|
||||||
|
cute::is_tuple<ElementA_>::value || cute::is_tuple<ElementB_>::value ||
|
||||||
|
// DirectConvert
|
||||||
|
sizeof_bits<ElementA_>::value != sizeof_bits<ElementB_>::value)>> {
|
||||||
|
private:
|
||||||
|
using ScaleA = detail::deduce_mixed_width_dtype_t<1, ElementA_>;
|
||||||
|
using ScaleB = detail::deduce_mixed_width_dtype_t<1, ElementB_>;
|
||||||
|
using ZeroA = detail::deduce_mixed_width_dtype_t<2, ElementA_>;
|
||||||
|
using ZeroB = detail::deduce_mixed_width_dtype_t<2, ElementB_>;
|
||||||
|
static constexpr bool NeitherIsTuple = !cute::is_tuple<ElementA_>::value && !cute::is_tuple<ElementB_>::value;
|
||||||
|
// Determine if mixed input types.
|
||||||
|
static constexpr bool IsMixedInput = cute::sizeof_bits_v<detail::deduce_mixed_width_dtype_t<0, ElementA_>> !=
|
||||||
|
cute::sizeof_bits_v<detail::deduce_mixed_width_dtype_t<0, ElementB_>>;
|
||||||
|
static constexpr bool IsArrayOfPointersGemm = cute::is_any_of_v<
|
||||||
|
KernelScheduleType,
|
||||||
|
KernelPtrArrayTmaWarpSpecializedCooperative,
|
||||||
|
KernelPtrArrayTmaWarpSpecializedPingpong>;
|
||||||
|
static_assert(IsMixedInput || !IsArrayOfPointersGemm, "Only mixed input grouped RS GEMM is supported.");
|
||||||
|
|
||||||
|
public:
|
||||||
|
using ElementA = detail::deduce_mixed_width_dtype_t<0, ElementA_>;
|
||||||
|
using ElementB = detail::deduce_mixed_width_dtype_t<0, ElementB_>;
|
||||||
|
|
||||||
|
static_assert(
|
||||||
|
!IsMixedInput || (cute::is_tuple<ElementA_>::value ^ cute::is_tuple<ElementB_>::value ||
|
||||||
|
(NeitherIsTuple && (sizeof_bits<ElementA>::value != sizeof_bits<ElementB>::value))),
|
||||||
|
"Either A OR B must be a tuple or the widths of A and B must be different.");
|
||||||
|
|
||||||
|
static constexpr bool IsANarrow = sizeof_bits<ElementA>::value < sizeof_bits<ElementB>::value;
|
||||||
|
|
||||||
|
template <class T>
|
||||||
|
static auto get_stride(T const& t) {
|
||||||
|
if constexpr (not cute::is_layout<cute::remove_pointer_t<T>>::value) {
|
||||||
|
return t;
|
||||||
|
} else {
|
||||||
|
if constexpr (cute::is_pointer_v<T>) {
|
||||||
|
return &cute::stride(*t);
|
||||||
|
} else {
|
||||||
|
return cute::stride(t);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
using GmemLayoutATag = decltype(get_stride(GmemLayoutATag_{}));
|
||||||
|
using GmemLayoutBTag = decltype(get_stride(GmemLayoutBTag_{}));
|
||||||
|
|
||||||
|
using ElementPairA =
|
||||||
|
cute::conditional_t<IsMixedInput && IsANarrow && NeitherIsTuple, cute::tuple<ElementA>, ElementA_>;
|
||||||
|
using ElementPairB =
|
||||||
|
cute::conditional_t<IsMixedInput && !IsANarrow && NeitherIsTuple, cute::tuple<ElementB>, ElementB_>;
|
||||||
|
|
||||||
|
static constexpr bool IsATransformed = cute::is_tuple<ElementPairA>::value;
|
||||||
|
using ElementScale = cute::conditional_t<IsATransformed, ScaleA, ScaleB>;
|
||||||
|
using ElementZero = cute::conditional_t<IsATransformed, ZeroA, ZeroB>;
|
||||||
|
|
||||||
|
static_assert(is_static<TileShape_MNK>::value);
|
||||||
|
static_assert(is_static<ClusterShape_MNK>::value);
|
||||||
|
static_assert(
|
||||||
|
detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>(),
|
||||||
|
"Should meet TMA alignment requirement\n");
|
||||||
|
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
|
||||||
|
static_assert(cutlass::detail::dependent_false<ElementA>, "Unsupported Toolkit for SM90 Collective Builder\n");
|
||||||
|
#endif
|
||||||
|
static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_rs_tag_to_major_A<GmemLayoutATag>();
|
||||||
|
static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_rs_tag_to_major_B<GmemLayoutBTag>();
|
||||||
|
// If A is scaled, then we don't need to swap. Otherwise, we must ensure B goes to rmem and we must swap the
|
||||||
|
// operands.
|
||||||
|
static constexpr bool SwapAB =
|
||||||
|
IsMixedInput ? !IsATransformed : detail::is_swapAB<ElementA, GmemLayoutATag, ElementB, GmemLayoutBTag>();
|
||||||
|
static constexpr bool IsWarpSpecializedTransposeB =
|
||||||
|
detail::is_warpspecialized_transpose_B<ElementA, GmemLayoutATag, ElementB, GmemLayoutBTag, KernelScheduleType>();
|
||||||
|
static_assert(!IsMixedInput || !IsWarpSpecializedTransposeB, "Mixed input GEMM does not support WS transpose B.");
|
||||||
|
|
||||||
|
// When we relax the above assertion, we must handle setting the tile mma GmmaMajorB correctly.
|
||||||
|
static constexpr cute::GMMA::Major TiledMmaGmmaMajorB = SwapAB ? GmmaMajorA : GmmaMajorB;
|
||||||
|
|
||||||
|
// For fp32 types, map to tf32 MMA value type.
|
||||||
|
using ElementAMma = cute::conditional_t<cute::is_same_v<ElementA, float>, tfloat32_t, ElementA>;
|
||||||
|
using ElementBMma = cute::conditional_t<cute::is_same_v<ElementB, float>, tfloat32_t, ElementB>;
|
||||||
|
|
||||||
|
// Handle mixed dtypes and MMA.
|
||||||
|
using RealElementA = cute::conditional_t<SwapAB, ElementBMma, ElementAMma>;
|
||||||
|
using RealElementB = cute::conditional_t<SwapAB, ElementAMma, ElementBMma>;
|
||||||
|
using RealElementAMma = cute::conditional_t<IsMixedInput, RealElementB, RealElementA>;
|
||||||
|
// Always the same for element B.
|
||||||
|
using RealElementBMma = RealElementB;
|
||||||
|
|
||||||
|
static_assert(
|
||||||
|
!IsMixedInput || TiledMmaGmmaMajorB == GMMA::Major::K || sizeof_bits<RealElementB>::value == 16,
|
||||||
|
"Mixed input GEMM does not support MN major layout except for 16bit");
|
||||||
|
|
||||||
|
using AtomLayoutMNK = cute::conditional_t<
|
||||||
|
cute::is_any_of_v<
|
||||||
|
KernelScheduleType,
|
||||||
|
KernelTmaWarpSpecializedCooperative,
|
||||||
|
KernelPtrArrayTmaWarpSpecializedCooperative>,
|
||||||
|
Layout<Shape<_2, _1, _1>>,
|
||||||
|
Layout<Shape<_1, _1, _1>>>;
|
||||||
|
|
||||||
|
using TiledMma = decltype(cute::make_tiled_mma(
|
||||||
|
cute::GMMA::rs_op_selector<
|
||||||
|
RealElementAMma,
|
||||||
|
RealElementBMma,
|
||||||
|
ElementAccumulator,
|
||||||
|
TileShape_MNK,
|
||||||
|
GMMA::Major::K,
|
||||||
|
GMMA::Major::K>(),
|
||||||
|
AtomLayoutMNK{}));
|
||||||
|
|
||||||
|
using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
|
||||||
|
using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{})));
|
||||||
|
|
||||||
|
using SmemLayoutAtomA = decltype(detail::rs_smem_selector<
|
||||||
|
GmmaMajorA,
|
||||||
|
ElementAMma,
|
||||||
|
decltype(cute::get<0>(TileShape_MNK{})),
|
||||||
|
decltype(cute::get<2>(TileShape_MNK{})),
|
||||||
|
IsWarpSpecializedTransposeB>());
|
||||||
|
using SmemLayoutAtomB = decltype(detail::rs_smem_selector<
|
||||||
|
GmmaMajorB,
|
||||||
|
ElementBMma,
|
||||||
|
decltype(cute::get<1>(TileShape_MNK{})),
|
||||||
|
decltype(cute::get<2>(TileShape_MNK{})),
|
||||||
|
IsWarpSpecializedTransposeB>());
|
||||||
|
|
||||||
|
static constexpr size_t SmemAlignmentA = cutlass::detail::alignment_for_swizzle(SmemLayoutAtomA{});
|
||||||
|
static constexpr size_t SmemAlignmentB = cutlass::detail::alignment_for_swizzle(SmemLayoutAtomB{});
|
||||||
|
static constexpr int SmemAlignment = static_cast<int>(cute::max(SmemAlignmentA, SmemAlignmentB));
|
||||||
|
|
||||||
|
// Handle mixed dtype array GEMM's size of tensor map storage.
|
||||||
|
static constexpr size_t TensorMapStorage = sizeof(cute::TmaDescriptor) * size_t(IsMixedInput) * 4;
|
||||||
|
static constexpr int KernelSmemCarveout = static_cast<int>(TensorMapStorage);
|
||||||
|
static constexpr int Sm90ReducedSmemCapacityBytes = detail::sm90_smem_capacity_bytes - KernelSmemCarveout;
|
||||||
|
|
||||||
|
static constexpr int PipelineStages =
|
||||||
|
IsMixedInput ? (IsArrayOfPointersGemm ? detail::compute_stage_count_or_override_single_affine_transformed_input<
|
||||||
|
Sm90ReducedSmemCapacityBytes,
|
||||||
|
RealElementA,
|
||||||
|
RealElementB,
|
||||||
|
ElementScale,
|
||||||
|
ElementZero,
|
||||||
|
TileShape_MNK,
|
||||||
|
StageCountType::bytes,
|
||||||
|
SmemAlignment>(StageCountType{})
|
||||||
|
: detail::compute_stage_count_or_override_single_affine_transformed_input<
|
||||||
|
detail::sm90_smem_capacity_bytes,
|
||||||
|
RealElementA,
|
||||||
|
RealElementB,
|
||||||
|
ElementScale,
|
||||||
|
ElementZero,
|
||||||
|
TileShape_MNK,
|
||||||
|
StageCountType::bytes,
|
||||||
|
SmemAlignment>(StageCountType{}))
|
||||||
|
: detail::compute_stage_count_or_override<
|
||||||
|
detail::sm90_smem_capacity_bytes,
|
||||||
|
ElementAMma,
|
||||||
|
ElementBMma,
|
||||||
|
TileShape_MNK,
|
||||||
|
StageCountType::bytes,
|
||||||
|
SmemAlignment>(StageCountType{});
|
||||||
|
|
||||||
|
using DispatchPolicy = cute::conditional_t<
|
||||||
|
IsMixedInput,
|
||||||
|
cute::conditional_t<
|
||||||
|
IsArrayOfPointersGemm,
|
||||||
|
MainloopSm90ArrayTmaGmmaWarpSpecializedMixedInput<PipelineStages, ClusterShape_MNK, KernelScheduleType>,
|
||||||
|
MainloopSm90TmaGmmaRmemAWarpSpecializedMixedInput<PipelineStages, ClusterShape_MNK, KernelScheduleType>>,
|
||||||
|
MainloopSm90TmaGmmaRmemAWarpSpecialized<PipelineStages, ClusterShape_MNK, KernelScheduleType>>;
|
||||||
|
|
||||||
|
using SmemCopyAtomA = cute::conditional_t<SwapAB, void, Copy_Atom<cute::AutoVectorizingCopy, ElementA>>;
|
||||||
|
using SmemCopyAtomB = cute::conditional_t<SwapAB, Copy_Atom<cute::AutoVectorizingCopy, ElementB>, void>;
|
||||||
|
|
||||||
|
// We pack the scale data with the operand that will be optionally scaled and converted before MMA.
|
||||||
|
using StrideA = cute::conditional_t<
|
||||||
|
cute::is_layout<cute::remove_pointer_t<GmemLayoutATag_>>::value,
|
||||||
|
GmemLayoutATag_,
|
||||||
|
TagToStrideA_t<GmemLayoutATag>>;
|
||||||
|
using StrideB = cute::conditional_t<
|
||||||
|
cute::is_layout<cute::remove_pointer_t<GmemLayoutBTag_>>::value,
|
||||||
|
GmemLayoutBTag_,
|
||||||
|
TagToStrideB_t<GmemLayoutBTag>>;
|
||||||
|
|
||||||
|
using CollectiveOp = CollectiveMmaArrayMixedInput<
|
||||||
|
DispatchPolicy,
|
||||||
|
TileShape_MNK,
|
||||||
|
ElementPairA,
|
||||||
|
StrideA,
|
||||||
|
ElementPairB,
|
||||||
|
StrideB,
|
||||||
|
TiledMma,
|
||||||
|
GmemTiledCopyA,
|
||||||
|
SmemLayoutAtomA,
|
||||||
|
SmemCopyAtomA,
|
||||||
|
cute::identity,
|
||||||
|
GmemTiledCopyB,
|
||||||
|
SmemLayoutAtomB,
|
||||||
|
SmemCopyAtomB,
|
||||||
|
cute::identity>;
|
||||||
|
|
||||||
|
static_assert(
|
||||||
|
SmemAlignment == static_cast<int>(cute::max(CollectiveOp::SmemAlignmentA, CollectiveOp::SmemAlignmentB)));
|
||||||
|
};
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
} // namespace cutlass::gemm::collective
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
@@ -0,0 +1,52 @@
|
|||||||
|
/*
|
||||||
|
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||||
|
#include "cutlass_extensions/gemm/collective/collective_mma_array_mixed_input.hpp"
|
||||||
|
|
||||||
|
namespace cutlass::gemm::collective {
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
template <
|
||||||
|
class ArchTag,
|
||||||
|
class OpClass,
|
||||||
|
class ElementA,
|
||||||
|
class GmemLayoutA,
|
||||||
|
int AlignmentA,
|
||||||
|
class ElementB,
|
||||||
|
class GmemLayoutB,
|
||||||
|
int AlignmentB,
|
||||||
|
class ElementAccumulator,
|
||||||
|
class TileShape_MNK,
|
||||||
|
class ClusterShape_MNK,
|
||||||
|
class StageCountType,
|
||||||
|
class KernelScheduleType,
|
||||||
|
class Enable = void>
|
||||||
|
struct CollectiveBuilderMixedInput {
|
||||||
|
static_assert(sizeof(ElementA) == 0, "Could not build a collective for given parameters.");
|
||||||
|
};
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
} // namespace cutlass::gemm::collective
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
#include "cutlass_extensions/gemm/collective/builders/sm90_gmma_builder_mixed_input.inl"
|
||||||
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
@@ -0,0 +1,53 @@
|
|||||||
|
/*
|
||||||
|
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "cutlass/detail/dependent_false.hpp"
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
namespace cutlass::gemm::collective {
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
template <
|
||||||
|
class DispatchPolicy,
|
||||||
|
class TileShape,
|
||||||
|
class ElementA,
|
||||||
|
class StrideA,
|
||||||
|
class ElementB,
|
||||||
|
class StrideB,
|
||||||
|
class TiledMma,
|
||||||
|
class GmemTiledCopyA,
|
||||||
|
class SmemLayoutAtomA,
|
||||||
|
class SmemCopyAtomA,
|
||||||
|
class TransformA,
|
||||||
|
class GmemTiledCopyB,
|
||||||
|
class SmemLayoutAtomB,
|
||||||
|
class SmemCopyAtomB,
|
||||||
|
class TransformB>
|
||||||
|
struct CollectiveMmaArrayMixedInput {
|
||||||
|
static_assert(cutlass::detail::dependent_false<ElementA>, "Could not find a mainloop specialization.");
|
||||||
|
};
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
} // namespace cutlass::gemm::collective
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
#include "cutlass_extensions/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input_.hpp"
|
||||||
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
File diff suppressed because it is too large
Load Diff
91
sgl-kernel/csrc/moe/cutlass_moe/w4a8/scaled_mm_entry.cu
Normal file
91
sgl-kernel/csrc/moe/cutlass_moe/w4a8/scaled_mm_entry.cu
Normal file
@@ -0,0 +1,91 @@
|
|||||||
|
#include <c10/cuda/CUDAGuard.h>
|
||||||
|
#include <cudaTypedefs.h>
|
||||||
|
#include <torch/all.h>
|
||||||
|
|
||||||
|
int32_t get_sm_version_num() {
|
||||||
|
int32_t major_capability, minor_capability;
|
||||||
|
cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor, 0);
|
||||||
|
cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor, 0);
|
||||||
|
int32_t version_num = major_capability * 10 + minor_capability;
|
||||||
|
return version_num;
|
||||||
|
}
|
||||||
|
|
||||||
|
void cutlass_w4a8_moe_mm_sm90(
|
||||||
|
torch::Tensor& d_tensors,
|
||||||
|
torch::Tensor const& a_tensors,
|
||||||
|
torch::Tensor const& b_tensors,
|
||||||
|
torch::Tensor const& a_scales,
|
||||||
|
torch::Tensor const& b_scales,
|
||||||
|
torch::Tensor const& expert_offsets,
|
||||||
|
torch::Tensor const& problem_sizes,
|
||||||
|
torch::Tensor const& a_strides,
|
||||||
|
torch::Tensor const& b_strides,
|
||||||
|
torch::Tensor const& d_strides,
|
||||||
|
torch::Tensor const& s_strides,
|
||||||
|
int64_t chunk_size,
|
||||||
|
int64_t topk);
|
||||||
|
|
||||||
|
void get_cutlass_w4a8_moe_mm_data_caller(
|
||||||
|
const torch::Tensor& topk_ids,
|
||||||
|
torch::Tensor& expert_offsets,
|
||||||
|
torch::Tensor& problem_sizes1,
|
||||||
|
torch::Tensor& problem_sizes2,
|
||||||
|
torch::Tensor& input_permutation,
|
||||||
|
torch::Tensor& output_permutation,
|
||||||
|
const int64_t num_experts,
|
||||||
|
const int64_t n,
|
||||||
|
const int64_t k);
|
||||||
|
|
||||||
|
void cutlass_w4a8_moe_mm(
|
||||||
|
torch::Tensor& d_tensors,
|
||||||
|
torch::Tensor const& a_tensors,
|
||||||
|
torch::Tensor const& b_tensors,
|
||||||
|
torch::Tensor const& a_scales,
|
||||||
|
torch::Tensor const& b_scales,
|
||||||
|
torch::Tensor const& expert_offsets,
|
||||||
|
torch::Tensor const& problem_sizes,
|
||||||
|
torch::Tensor const& a_strides,
|
||||||
|
torch::Tensor const& b_strides,
|
||||||
|
torch::Tensor const& d_strides,
|
||||||
|
torch::Tensor const& s_strides,
|
||||||
|
int64_t chunk_size,
|
||||||
|
int64_t topk) {
|
||||||
|
cutlass_w4a8_moe_mm_sm90(
|
||||||
|
d_tensors,
|
||||||
|
a_tensors,
|
||||||
|
b_tensors,
|
||||||
|
a_scales,
|
||||||
|
b_scales,
|
||||||
|
expert_offsets,
|
||||||
|
problem_sizes,
|
||||||
|
a_strides,
|
||||||
|
b_strides,
|
||||||
|
d_strides,
|
||||||
|
s_strides,
|
||||||
|
chunk_size,
|
||||||
|
topk);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
void get_cutlass_w4a8_moe_mm_data(
|
||||||
|
const torch::Tensor& topk_ids,
|
||||||
|
torch::Tensor& expert_offsets,
|
||||||
|
torch::Tensor& problem_sizes1,
|
||||||
|
torch::Tensor& problem_sizes2,
|
||||||
|
torch::Tensor& input_permutation,
|
||||||
|
torch::Tensor& output_permutation,
|
||||||
|
const int64_t num_experts,
|
||||||
|
const int64_t n,
|
||||||
|
const int64_t k) {
|
||||||
|
get_cutlass_w4a8_moe_mm_data_caller(
|
||||||
|
topk_ids,
|
||||||
|
expert_offsets,
|
||||||
|
problem_sizes1,
|
||||||
|
problem_sizes2,
|
||||||
|
input_permutation,
|
||||||
|
output_permutation,
|
||||||
|
num_experts,
|
||||||
|
n,
|
||||||
|
k);
|
||||||
|
return;
|
||||||
|
}
|
||||||
@@ -0,0 +1,92 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <c10/cuda/CUDAStream.h>
|
||||||
|
#include <cuda.h>
|
||||||
|
#include <torch/all.h>
|
||||||
|
|
||||||
|
#include "cutlass/bfloat16.h"
|
||||||
|
#include "cutlass/float8.h"
|
||||||
|
|
||||||
|
template <typename ElementA, typename ElementB, typename ElementC, typename ElementAccumulator>
|
||||||
|
__global__ void int4_fp8_get_group_gemm_starts(
|
||||||
|
int32_t* expert_offsets,
|
||||||
|
ElementA** a_offsets,
|
||||||
|
ElementB** b_offsets,
|
||||||
|
ElementC** out_offsets,
|
||||||
|
ElementAccumulator** a_scales_offsets,
|
||||||
|
cutlass::bfloat16_t** b_scales_offsets,
|
||||||
|
ElementA* a_base_as_int,
|
||||||
|
ElementB* b_base_as_int,
|
||||||
|
ElementC* out_base_as_int,
|
||||||
|
ElementAccumulator* a_scales_base_as_int,
|
||||||
|
cutlass::bfloat16_t* b_scales_base_as_int,
|
||||||
|
int64_t n,
|
||||||
|
int64_t k,
|
||||||
|
bool per_act_token,
|
||||||
|
bool per_out_ch) {
|
||||||
|
int expert_id = threadIdx.x;
|
||||||
|
int32_t expert_offset = expert_offsets[expert_id];
|
||||||
|
|
||||||
|
a_offsets[expert_id] = a_base_as_int + expert_offset * k;
|
||||||
|
b_offsets[expert_id] = b_base_as_int + expert_id * k * n / 2;
|
||||||
|
out_offsets[expert_id] = out_base_as_int + expert_offset * n;
|
||||||
|
a_scales_offsets[expert_id] = a_scales_base_as_int + (per_act_token ? expert_offset : 0);
|
||||||
|
b_scales_offsets[expert_id] = b_scales_base_as_int + (per_out_ch ? expert_id * n * 4 * k / 512 : expert_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
#define __CALL_W4A8_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE) \
|
||||||
|
else if (out_tensors.dtype() == TENSOR_C_TYPE) { \
|
||||||
|
int4_fp8_get_group_gemm_starts<cutlass::float_e4m3_t, cutlass::int8_t, C_TYPE, float> \
|
||||||
|
<<<1, num_experts, 0, stream>>>( \
|
||||||
|
static_cast<int32_t*>(expert_offsets.data_ptr()), \
|
||||||
|
static_cast<cutlass::float_e4m3_t**>(a_ptrs.data_ptr()), \
|
||||||
|
static_cast<cutlass::int8_t**>(b_ptrs.data_ptr()), \
|
||||||
|
static_cast<C_TYPE**>(out_ptrs.data_ptr()), \
|
||||||
|
static_cast<float**>(a_scales_ptrs.data_ptr()), \
|
||||||
|
static_cast<cutlass::bfloat16_t**>(b_scales_ptrs.data_ptr()), \
|
||||||
|
static_cast<cutlass::float_e4m3_t*>(a_tensors.data_ptr()), \
|
||||||
|
static_cast<cutlass::int8_t*>(b_tensors.data_ptr()), \
|
||||||
|
static_cast<C_TYPE*>(out_tensors.data_ptr()), \
|
||||||
|
static_cast<float*>(a_scales.data_ptr()), \
|
||||||
|
static_cast<cutlass::bfloat16_t*>(b_scales.data_ptr()), \
|
||||||
|
out_tensors.size(1), \
|
||||||
|
a_tensors.size(1), \
|
||||||
|
per_act_token, \
|
||||||
|
per_out_ch); \
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
void run_int4_fp8_get_group_gemm_starts(
|
||||||
|
torch::Tensor const& expert_offsets,
|
||||||
|
torch::Tensor& a_ptrs,
|
||||||
|
torch::Tensor& b_ptrs,
|
||||||
|
torch::Tensor& out_ptrs,
|
||||||
|
torch::Tensor& a_scales_ptrs,
|
||||||
|
torch::Tensor& b_scales_ptrs,
|
||||||
|
torch::Tensor const& a_tensors,
|
||||||
|
torch::Tensor const& b_tensors,
|
||||||
|
torch::Tensor& out_tensors,
|
||||||
|
torch::Tensor const& a_scales,
|
||||||
|
torch::Tensor const& b_scales) {
|
||||||
|
TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn);
|
||||||
|
TORCH_CHECK(b_tensors.dtype() == torch::kInt8);
|
||||||
|
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||||
|
TORCH_CHECK(b_scales.dtype() == torch::kBFloat16);
|
||||||
|
|
||||||
|
int num_experts = static_cast<int>(expert_offsets.size(0));
|
||||||
|
bool per_act_token = a_scales.numel() != 1;
|
||||||
|
bool per_out_ch = b_scales.numel() != num_experts;
|
||||||
|
|
||||||
|
auto stream = at::cuda::getCurrentCUDAStream(expert_offsets.device().index());
|
||||||
|
|
||||||
|
if (false) {
|
||||||
|
}
|
||||||
|
__CALL_W4A8_GET_STARTS_KERNEL(torch::kBFloat16, cutlass::bfloat16_t)
|
||||||
|
__CALL_W4A8_GET_STARTS_KERNEL(torch::kFloat16, half)
|
||||||
|
else {
|
||||||
|
TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
240
sgl-kernel/csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cu
Normal file
240
sgl-kernel/csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cu
Normal file
@@ -0,0 +1,240 @@
|
|||||||
|
#include <c10/cuda/CUDAGuard.h>
|
||||||
|
#include <cudaTypedefs.h>
|
||||||
|
#include <torch/all.h>
|
||||||
|
|
||||||
|
#include "cutlass/cutlass.h"
|
||||||
|
#include "w4a8_grouped_mm_c3x.cuh"
|
||||||
|
|
||||||
|
using namespace cute;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
#define JOIN_STRUCT_NAME(m, n, k, a, b, c) sm90_fp8_config##_##m##_##n##_##k##_##a##_##b##_##c
|
||||||
|
|
||||||
|
#define JOIN_STRUCT_NAME_CO(m, n, k, a, b, c) sm90_fp8_co_config##_##m##_##n##_##k##_##a##_##b##_##c
|
||||||
|
|
||||||
|
#define GENERATE_SM90_W4A8_PP_CONFIG(M, N, K, A, B, C) \
|
||||||
|
struct JOIN_STRUCT_NAME(M, N, K, A, B, C) { \
|
||||||
|
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; \
|
||||||
|
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; \
|
||||||
|
using TileShape = cute::Shape<cute::Int<M>, cute::Int<N>, cute::Int<K>>; \
|
||||||
|
using ClusterShape = cute::Shape<cute::Int<A>, cute::Int<B>, cute::Int<C>>; \
|
||||||
|
\
|
||||||
|
using Cutlass3xW4A8Gemm = cutlass_3x_w4a8_group_gemm<TileShape, ClusterShape, KernelSchedule, EpilogueSchedule>; \
|
||||||
|
};
|
||||||
|
|
||||||
|
#define GENERATE_SM90_W4A8_CO_CONFIG(M, N, K, A, B, C) \
|
||||||
|
struct JOIN_STRUCT_NAME_CO(M, N, K, A, B, C) { \
|
||||||
|
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; \
|
||||||
|
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; \
|
||||||
|
using TileShape = cute::Shape<cute::Int<M>, cute::Int<N>, cute::Int<K>>; \
|
||||||
|
using ClusterShape = cute::Shape<cute::Int<A>, cute::Int<B>, cute::Int<C>>; \
|
||||||
|
\
|
||||||
|
using Cutlass3xW4A8Gemm = cutlass_3x_w4a8_group_gemm<TileShape, ClusterShape, KernelSchedule, EpilogueSchedule>; \
|
||||||
|
};
|
||||||
|
|
||||||
|
GENERATE_SM90_W4A8_PP_CONFIG(64, 16, 512, 1, 1, 1)
|
||||||
|
GENERATE_SM90_W4A8_PP_CONFIG(64, 32, 512, 2, 1, 1)
|
||||||
|
|
||||||
|
GENERATE_SM90_W4A8_CO_CONFIG(128, 16, 512, 1, 1, 1)
|
||||||
|
GENERATE_SM90_W4A8_CO_CONFIG(128, 16, 512, 2, 1, 1)
|
||||||
|
GENERATE_SM90_W4A8_CO_CONFIG(128, 32, 512, 1, 1, 1)
|
||||||
|
GENERATE_SM90_W4A8_CO_CONFIG(128, 32, 512, 2, 1, 1)
|
||||||
|
GENERATE_SM90_W4A8_CO_CONFIG(128, 64, 512, 1, 1, 1)
|
||||||
|
|
||||||
|
void dispatch_w4a8_moe_mm_sm90(
|
||||||
|
torch::Tensor& d_tensors,
|
||||||
|
torch::Tensor const& a_tensors,
|
||||||
|
torch::Tensor const& b_tensors,
|
||||||
|
torch::Tensor const& a_scales,
|
||||||
|
torch::Tensor const& b_scales,
|
||||||
|
torch::Tensor const& expert_offsets,
|
||||||
|
torch::Tensor const& problem_sizes,
|
||||||
|
torch::Tensor const& a_strides,
|
||||||
|
torch::Tensor const& b_strides,
|
||||||
|
torch::Tensor const& d_strides,
|
||||||
|
torch::Tensor const& s_strides,
|
||||||
|
int64_t chunk_size,
|
||||||
|
int64_t topk) {
|
||||||
|
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative;
|
||||||
|
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative;
|
||||||
|
|
||||||
|
uint32_t const m = a_tensors.size(0) / topk;
|
||||||
|
uint32_t const n = d_tensors.size(1);
|
||||||
|
uint32_t const k = a_tensors.size(1);
|
||||||
|
|
||||||
|
if (n == 4096 && k == 7168) {
|
||||||
|
// group gemm 1
|
||||||
|
if (m <= 4) {
|
||||||
|
using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME(64, 32, 512, 2, 1, 1)::Cutlass3xW4A8Gemm;
|
||||||
|
cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
|
||||||
|
d_tensors,
|
||||||
|
a_tensors,
|
||||||
|
b_tensors,
|
||||||
|
a_scales,
|
||||||
|
b_scales,
|
||||||
|
expert_offsets,
|
||||||
|
problem_sizes,
|
||||||
|
a_strides,
|
||||||
|
b_strides,
|
||||||
|
d_strides,
|
||||||
|
s_strides,
|
||||||
|
chunk_size);
|
||||||
|
} else if (m <= 16) {
|
||||||
|
using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 16, 512, 2, 1, 1)::Cutlass3xW4A8Gemm;
|
||||||
|
cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
|
||||||
|
d_tensors,
|
||||||
|
a_tensors,
|
||||||
|
b_tensors,
|
||||||
|
a_scales,
|
||||||
|
b_scales,
|
||||||
|
expert_offsets,
|
||||||
|
problem_sizes,
|
||||||
|
a_strides,
|
||||||
|
b_strides,
|
||||||
|
d_strides,
|
||||||
|
s_strides,
|
||||||
|
chunk_size);
|
||||||
|
} else if (m <= 256) {
|
||||||
|
using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 16, 512, 1, 1, 1)::Cutlass3xW4A8Gemm;
|
||||||
|
cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
|
||||||
|
d_tensors,
|
||||||
|
a_tensors,
|
||||||
|
b_tensors,
|
||||||
|
a_scales,
|
||||||
|
b_scales,
|
||||||
|
expert_offsets,
|
||||||
|
problem_sizes,
|
||||||
|
a_strides,
|
||||||
|
b_strides,
|
||||||
|
d_strides,
|
||||||
|
s_strides,
|
||||||
|
chunk_size);
|
||||||
|
} else if (m <= 1024) {
|
||||||
|
using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 32, 512, 2, 1, 1)::Cutlass3xW4A8Gemm;
|
||||||
|
cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
|
||||||
|
d_tensors,
|
||||||
|
a_tensors,
|
||||||
|
b_tensors,
|
||||||
|
a_scales,
|
||||||
|
b_scales,
|
||||||
|
expert_offsets,
|
||||||
|
problem_sizes,
|
||||||
|
a_strides,
|
||||||
|
b_strides,
|
||||||
|
d_strides,
|
||||||
|
s_strides,
|
||||||
|
chunk_size);
|
||||||
|
} else {
|
||||||
|
using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 64, 512, 1, 1, 1)::Cutlass3xW4A8Gemm;
|
||||||
|
cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
|
||||||
|
d_tensors,
|
||||||
|
a_tensors,
|
||||||
|
b_tensors,
|
||||||
|
a_scales,
|
||||||
|
b_scales,
|
||||||
|
expert_offsets,
|
||||||
|
problem_sizes,
|
||||||
|
a_strides,
|
||||||
|
b_strides,
|
||||||
|
d_strides,
|
||||||
|
s_strides,
|
||||||
|
chunk_size);
|
||||||
|
}
|
||||||
|
} else if (n == 7168 && k == 2048) {
|
||||||
|
// group gemm 2
|
||||||
|
if (m <= 8) {
|
||||||
|
using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME(64, 16, 512, 1, 1, 1)::Cutlass3xW4A8Gemm;
|
||||||
|
cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
|
||||||
|
d_tensors,
|
||||||
|
a_tensors,
|
||||||
|
b_tensors,
|
||||||
|
a_scales,
|
||||||
|
b_scales,
|
||||||
|
expert_offsets,
|
||||||
|
problem_sizes,
|
||||||
|
a_strides,
|
||||||
|
b_strides,
|
||||||
|
d_strides,
|
||||||
|
s_strides,
|
||||||
|
chunk_size);
|
||||||
|
} else if (m <= 512) {
|
||||||
|
using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 32, 512, 1, 1, 1)::Cutlass3xW4A8Gemm;
|
||||||
|
cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
|
||||||
|
d_tensors,
|
||||||
|
a_tensors,
|
||||||
|
b_tensors,
|
||||||
|
a_scales,
|
||||||
|
b_scales,
|
||||||
|
expert_offsets,
|
||||||
|
problem_sizes,
|
||||||
|
a_strides,
|
||||||
|
b_strides,
|
||||||
|
d_strides,
|
||||||
|
s_strides,
|
||||||
|
chunk_size);
|
||||||
|
} else {
|
||||||
|
using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 64, 512, 1, 1, 1)::Cutlass3xW4A8Gemm;
|
||||||
|
cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
|
||||||
|
d_tensors,
|
||||||
|
a_tensors,
|
||||||
|
b_tensors,
|
||||||
|
a_scales,
|
||||||
|
b_scales,
|
||||||
|
expert_offsets,
|
||||||
|
problem_sizes,
|
||||||
|
a_strides,
|
||||||
|
b_strides,
|
||||||
|
d_strides,
|
||||||
|
s_strides,
|
||||||
|
chunk_size);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 32, 512, 1, 1, 1)::Cutlass3xW4A8Gemm;
|
||||||
|
cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
|
||||||
|
d_tensors,
|
||||||
|
a_tensors,
|
||||||
|
b_tensors,
|
||||||
|
a_scales,
|
||||||
|
b_scales,
|
||||||
|
expert_offsets,
|
||||||
|
problem_sizes,
|
||||||
|
a_strides,
|
||||||
|
b_strides,
|
||||||
|
d_strides,
|
||||||
|
s_strides,
|
||||||
|
chunk_size);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void cutlass_w4a8_moe_mm_sm90(
|
||||||
|
torch::Tensor& d_tensors,
|
||||||
|
torch::Tensor const& a_tensors,
|
||||||
|
torch::Tensor const& b_tensors,
|
||||||
|
torch::Tensor const& a_scales,
|
||||||
|
torch::Tensor const& b_scales,
|
||||||
|
torch::Tensor const& expert_offsets,
|
||||||
|
torch::Tensor const& problem_sizes,
|
||||||
|
torch::Tensor const& a_strides,
|
||||||
|
torch::Tensor const& b_strides,
|
||||||
|
torch::Tensor const& d_strides,
|
||||||
|
torch::Tensor const& s_strides,
|
||||||
|
int64_t chunk_size,
|
||||||
|
int64_t topk) {
|
||||||
|
dispatch_w4a8_moe_mm_sm90(
|
||||||
|
d_tensors,
|
||||||
|
a_tensors,
|
||||||
|
b_tensors,
|
||||||
|
a_scales,
|
||||||
|
b_scales,
|
||||||
|
expert_offsets,
|
||||||
|
problem_sizes,
|
||||||
|
a_strides,
|
||||||
|
b_strides,
|
||||||
|
d_strides,
|
||||||
|
s_strides,
|
||||||
|
chunk_size,
|
||||||
|
topk);
|
||||||
|
}
|
||||||
276
sgl-kernel/csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cuh
Normal file
276
sgl-kernel/csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cuh
Normal file
@@ -0,0 +1,276 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @file w4a8_grouped_mm_c3x.cuh
|
||||||
|
* @brief Implementation of grouped GEMM operation with int4 and fp8 mixed
|
||||||
|
* precision
|
||||||
|
*
|
||||||
|
* This file implements a grouped GEMM operation that multiplies FP8 matrices
|
||||||
|
* (A) with quantized INT4 matrices (B), applying per-block scaling factors.
|
||||||
|
* The implementation is optimized for NVIDIA Hopper GPUs, leveraging Tensor
|
||||||
|
* Cores for mixed precision arithmetic.
|
||||||
|
*
|
||||||
|
* Key features:
|
||||||
|
* - Supports grouped GEMM operations with multiple experts
|
||||||
|
* - Uses FP8 (e4m3) for matrix A
|
||||||
|
* - Uses INT4 quantization for matrix B with per-block scaling
|
||||||
|
* - Implements preprocessing for INT4 encoding and scale packing
|
||||||
|
* - Optimized for Hopper architecture with Tensor Core operations
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
|
#include <cuda_fp8.h>
|
||||||
|
#include <cuda_runtime.h>
|
||||||
|
#include <torch/all.h>
|
||||||
|
|
||||||
|
#include "cutlass/cutlass.h"
|
||||||
|
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||||
|
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||||
|
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||||
|
#include "cutlass/gemm/group_array_problem_shape.hpp"
|
||||||
|
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||||
|
#include "cutlass_extensions/gemm/collective/collective_builder_mixed_input.hpp"
|
||||||
|
#include "w4a8_get_group_starts.cuh"
|
||||||
|
|
||||||
|
using namespace cute;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// Type definitions
|
||||||
|
using MmaType = cutlass::float_e4m3_t; // FP8 e4m3 type
|
||||||
|
using QuantType = cutlass::int4b_t; // 4-bit integer type
|
||||||
|
using ElementAccumulator = float; // Accumulator type
|
||||||
|
using ElementScale = cutlass::bfloat16_t; // Scale type
|
||||||
|
using ElementScalePacked = cutlass::Array<ElementScale, 4>;
|
||||||
|
using ElementC = cutlass::half_t; // Default output type (FP16)
|
||||||
|
using ElementD = ElementC; // Default output type (FP16)
|
||||||
|
using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int, int, int>>;
|
||||||
|
|
||||||
|
// Architecture-specific configurations
|
||||||
|
using ArchTag = cutlass::arch::Sm90;
|
||||||
|
using OperatorClass = cutlass::arch::OpClassTensorOp;
|
||||||
|
// constexpr int TileShapeK = 512;
|
||||||
|
// using TileShape = Shape<_128, _32, cute::Int<TileShapeK>>;
|
||||||
|
// using ClusterShape = Shape<_1, _1, _1>;
|
||||||
|
|
||||||
|
// Layout configurations
|
||||||
|
using LayoutA = cutlass::layout::RowMajor;
|
||||||
|
using LayoutB = cutlass::layout::ColumnMajor;
|
||||||
|
using LayoutC = cutlass::layout::RowMajor;
|
||||||
|
using LayoutD = LayoutC;
|
||||||
|
|
||||||
|
// Transposed layouts
|
||||||
|
using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose<LayoutA>::type;
|
||||||
|
using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose<LayoutB>::type;
|
||||||
|
using LayoutC_Transpose = typename cutlass::layout::LayoutTranspose<LayoutC>::type;
|
||||||
|
using LayoutD_Transpose = typename cutlass::layout::LayoutTranspose<LayoutD>::type;
|
||||||
|
|
||||||
|
// Alignments
|
||||||
|
static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<MmaType>::value;
|
||||||
|
static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<QuantType>::value;
|
||||||
|
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||||
|
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||||
|
|
||||||
|
template <typename TileShape, typename ClusterShape, typename KernelSchedule, typename EpilogueSchedule>
|
||||||
|
struct cutlass_3x_w4a8_group_gemm {
|
||||||
|
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||||
|
ArchTag,
|
||||||
|
OperatorClass,
|
||||||
|
TileShape,
|
||||||
|
ClusterShape,
|
||||||
|
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||||
|
ElementAccumulator,
|
||||||
|
ElementAccumulator,
|
||||||
|
ElementC,
|
||||||
|
LayoutC_Transpose*,
|
||||||
|
AlignmentC,
|
||||||
|
ElementD,
|
||||||
|
LayoutD_Transpose*,
|
||||||
|
AlignmentD,
|
||||||
|
EpilogueSchedule>::CollectiveOp;
|
||||||
|
|
||||||
|
using CollectiveMainloopScaleOnly = typename cutlass::gemm::collective::CollectiveBuilderMixedInput<
|
||||||
|
ArchTag,
|
||||||
|
OperatorClass,
|
||||||
|
cute::tuple<QuantType, ElementScalePacked>,
|
||||||
|
LayoutB_Transpose*,
|
||||||
|
AlignmentB,
|
||||||
|
MmaType,
|
||||||
|
LayoutA_Transpose*,
|
||||||
|
AlignmentA,
|
||||||
|
ElementAccumulator,
|
||||||
|
TileShape,
|
||||||
|
ClusterShape,
|
||||||
|
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
|
||||||
|
sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||||
|
KernelSchedule>::CollectiveOp;
|
||||||
|
|
||||||
|
// Define the final kernel and GEMM operation types
|
||||||
|
using GemmKernelScaleOnly =
|
||||||
|
cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloopScaleOnly, CollectiveEpilogue>;
|
||||||
|
|
||||||
|
using GemmScaleOnly = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelScaleOnly>;
|
||||||
|
|
||||||
|
using StrideA = cute::remove_pointer_t<cutlass::detail::TagToStrideA_t<LayoutA*>>;
|
||||||
|
using StrideB = cute::remove_pointer_t<cutlass::detail::TagToStrideB_t<LayoutB*>>;
|
||||||
|
using StrideC = typename GemmKernelScaleOnly::InternalStrideC;
|
||||||
|
using StrideD = typename GemmKernelScaleOnly::InternalStrideD;
|
||||||
|
using StrideS = typename CollectiveMainloopScaleOnly::StrideScale;
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Main function to run int4 * fp8 grouped GEMM from PyTorch
|
||||||
|
*
|
||||||
|
* This function performs multiple GEMM operations in parallel where each
|
||||||
|
* operation multiplies an FP8 matrix (A) with a quantized INT4 matrix (B),
|
||||||
|
* applying per-channel scaling factors. It's designed for efficient execution
|
||||||
|
* on NVIDIA Hopper GPUs, leveraging Tensor Cores for optimal performance with
|
||||||
|
* mixed precision arithmetic.
|
||||||
|
*
|
||||||
|
* The function includes preprocessing steps for both INT4 tensors and scale
|
||||||
|
* factors to ensure optimal performance and correct operation.
|
||||||
|
*
|
||||||
|
* @param d_tensors Output tensor D with shape [total_m, total_n]
|
||||||
|
* @param a_tensors Tensor containing all A matrices (fp8_e4m3) with shape
|
||||||
|
* [total_m, K]
|
||||||
|
* @param b_tensors Tensor containing all B matrices (int4 packed as int8) with
|
||||||
|
* shape [E, N, K/2]
|
||||||
|
* @param a_scales Tensor containing A matrix scale factors
|
||||||
|
* @param b_scales Tensor containing B matrix scale factors with shape [E,
|
||||||
|
* K//512, N*4]
|
||||||
|
* @param expert_offsets Tensor containing expert offsets for determining group
|
||||||
|
* boundaries (int32)
|
||||||
|
* @param problem_sizes Tensor containing problem sizes with shape [num_experts,
|
||||||
|
* 3] (M, N, K for each group) (int32)
|
||||||
|
* @param a_strides Stride information for A tensors
|
||||||
|
* @param b_strides Stride information for B tensors
|
||||||
|
* @param d_strides Stride information for D tensors
|
||||||
|
* @param s_strides Stride information for scale tensors
|
||||||
|
* @param chunk_size Size of each chunk for scales (K / number of scale chunks)
|
||||||
|
*/
|
||||||
|
// template <typename TileShape, typename ClusterShape, typename KernelSchedule, typename EpilogueSchedule>
|
||||||
|
template <typename Gemm>
|
||||||
|
void cutlass_w4a8_group_gemm_caller(
|
||||||
|
torch::Tensor& d_tensors,
|
||||||
|
torch::Tensor const& a_tensors,
|
||||||
|
torch::Tensor const& b_tensors,
|
||||||
|
torch::Tensor const& a_scales,
|
||||||
|
torch::Tensor const& b_scales,
|
||||||
|
torch::Tensor const& expert_offsets,
|
||||||
|
torch::Tensor const& problem_sizes,
|
||||||
|
torch::Tensor const& a_strides,
|
||||||
|
torch::Tensor const& b_strides,
|
||||||
|
torch::Tensor const& d_strides,
|
||||||
|
torch::Tensor const& s_strides,
|
||||||
|
int64_t chunk_size) {
|
||||||
|
// using Gemm = cutlass_3x_w4a8_group_gemm<TileShape, ClusterShape, KernelSchedule, EpilogueSchedule>;
|
||||||
|
using Args = typename Gemm::GemmScaleOnly::Arguments;
|
||||||
|
|
||||||
|
int num_experts = static_cast<int>(expert_offsets.size(0));
|
||||||
|
bool per_act_token = a_scales.numel() != 1;
|
||||||
|
bool per_out_ch = b_scales.numel() != num_experts;
|
||||||
|
|
||||||
|
// Check inputs
|
||||||
|
TORCH_CHECK(a_tensors.dim() == 2, "A tensor must be 2D");
|
||||||
|
TORCH_CHECK(b_tensors.dim() == 3, "B tensor must be 3D [E, N, K/2]");
|
||||||
|
TORCH_CHECK(b_scales.dim() == 3, "Scale tensor must be 3D [E, K//512, N*4]");
|
||||||
|
TORCH_CHECK(a_scales.dim() == 1, "A Scale tensor must be 1D [1]");
|
||||||
|
TORCH_CHECK(expert_offsets.dim() == 1, "expert_offsets must be a 1D tensor");
|
||||||
|
TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor");
|
||||||
|
|
||||||
|
// Check tensor shapes
|
||||||
|
TORCH_CHECK(problem_sizes.size(0) == num_experts, "problem_sizes must have num_experts rows");
|
||||||
|
TORCH_CHECK(problem_sizes.size(1) == 3, "problem_sizes must have 3 columns (N, M, K)");
|
||||||
|
TORCH_CHECK(b_tensors.size(0) == num_experts, "B tensor first dimension must match number of groups");
|
||||||
|
TORCH_CHECK(b_scales.size(0) == num_experts, "Scale tensor first dimension must match number of groups");
|
||||||
|
TORCH_CHECK(b_tensors.size(2) * 2 == a_tensors.size(1), "B tensor K/2 dimension must match A tensor K dimension");
|
||||||
|
TORCH_CHECK(b_scales.size(1) == a_tensors.size(1) / 512, "Scale tensor second dimension must be K//512");
|
||||||
|
TORCH_CHECK(b_scales.size(2) == 4 * b_tensors.size(1), "Scale tensor last dimension must be 4*N");
|
||||||
|
|
||||||
|
// Check tensor types
|
||||||
|
TORCH_CHECK(a_tensors.scalar_type() == torch::kFloat8_e4m3fn, "A tensor must be fp8 (float_e4m3_t) type");
|
||||||
|
TORCH_CHECK(b_tensors.scalar_type() == torch::kInt8, "B tensor must contain packed int4 values (stored as int8)");
|
||||||
|
TORCH_CHECK(expert_offsets.scalar_type() == torch::kInt32, "Expert offsets must be int32 type");
|
||||||
|
TORCH_CHECK(problem_sizes.scalar_type() == torch::kInt32, "Problem sizes must be int32 type");
|
||||||
|
|
||||||
|
auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
|
||||||
|
auto options_int = torch::TensorOptions().dtype(torch::kInt64).device(a_tensors.device());
|
||||||
|
|
||||||
|
torch::Tensor a_ptrs = torch::empty(num_experts, options_int);
|
||||||
|
torch::Tensor b_ptrs = torch::empty(num_experts, options_int);
|
||||||
|
torch::Tensor out_ptrs = torch::empty(num_experts, options_int);
|
||||||
|
torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int);
|
||||||
|
torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int);
|
||||||
|
|
||||||
|
cutlass::KernelHardwareInfo hw_info;
|
||||||
|
hw_info.device_id = a_tensors.device().index();
|
||||||
|
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
|
||||||
|
|
||||||
|
Args arguments;
|
||||||
|
decltype(arguments.epilogue.thread) fusion_args;
|
||||||
|
fusion_args.alpha = 1.0f;
|
||||||
|
fusion_args.beta = 0;
|
||||||
|
fusion_args.alpha_ptr = a_scales.data_ptr<float>();
|
||||||
|
;
|
||||||
|
fusion_args.beta_ptr = nullptr;
|
||||||
|
fusion_args.alpha_ptr_array = nullptr;
|
||||||
|
fusion_args.beta_ptr_array = nullptr;
|
||||||
|
fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 0};
|
||||||
|
fusion_args.dBeta = {cute::_0{}, cute::_0{}, 0};
|
||||||
|
|
||||||
|
ProblemShape::UnderlyingProblemShape* problem_sizes_as_shapes =
|
||||||
|
static_cast<ProblemShape::UnderlyingProblemShape*>(problem_sizes.data_ptr());
|
||||||
|
|
||||||
|
run_int4_fp8_get_group_gemm_starts(
|
||||||
|
expert_offsets,
|
||||||
|
a_ptrs,
|
||||||
|
b_ptrs,
|
||||||
|
out_ptrs,
|
||||||
|
a_scales_ptrs,
|
||||||
|
b_scales_ptrs,
|
||||||
|
a_tensors,
|
||||||
|
b_tensors,
|
||||||
|
d_tensors,
|
||||||
|
a_scales,
|
||||||
|
b_scales);
|
||||||
|
|
||||||
|
arguments = Args{
|
||||||
|
cutlass::gemm::GemmUniversalMode::kGrouped,
|
||||||
|
{num_experts, problem_sizes_as_shapes, nullptr},
|
||||||
|
{static_cast<const QuantType**>(b_ptrs.data_ptr()),
|
||||||
|
static_cast<typename Gemm::StrideB*>(b_strides.data_ptr()),
|
||||||
|
static_cast<const MmaType**>(a_ptrs.data_ptr()),
|
||||||
|
static_cast<typename Gemm::StrideA*>(a_strides.data_ptr()),
|
||||||
|
static_cast<const ElementScalePacked**>(b_scales_ptrs.data_ptr()),
|
||||||
|
static_cast<typename Gemm::StrideS*>(s_strides.data_ptr()),
|
||||||
|
static_cast<int>(chunk_size)},
|
||||||
|
{fusion_args,
|
||||||
|
nullptr,
|
||||||
|
nullptr,
|
||||||
|
static_cast<ElementD**>(out_ptrs.data_ptr()),
|
||||||
|
static_cast<typename Gemm::StrideD*>(d_strides.data_ptr())},
|
||||||
|
hw_info};
|
||||||
|
|
||||||
|
// Instantiate and run GEMM
|
||||||
|
typename Gemm::GemmScaleOnly gemm;
|
||||||
|
size_t workspace_size = Gemm::GemmScaleOnly::get_workspace_size(arguments);
|
||||||
|
auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a_tensors.device());
|
||||||
|
auto workspace = torch::empty(workspace_size, workspace_options);
|
||||||
|
|
||||||
|
cutlass::Status status = gemm.can_implement(arguments);
|
||||||
|
if (status != cutlass::Status::kSuccess) {
|
||||||
|
TORCH_CHECK(false, "GEMM implementation not supported");
|
||||||
|
}
|
||||||
|
|
||||||
|
status = gemm.initialize(arguments, workspace.data_ptr(), stream);
|
||||||
|
if (status != cutlass::Status::kSuccess) {
|
||||||
|
TORCH_CHECK(false, "GEMM initialization failed");
|
||||||
|
}
|
||||||
|
|
||||||
|
status = gemm.run(stream);
|
||||||
|
if (status != cutlass::Status::kSuccess) {
|
||||||
|
TORCH_CHECK(false, "GEMM execution failed");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
79
sgl-kernel/csrc/moe/cutlass_moe/w4a8/w4a8_moe_data.cu
Normal file
79
sgl-kernel/csrc/moe/cutlass_moe/w4a8/w4a8_moe_data.cu
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
#include <c10/cuda/CUDAGuard.h>
|
||||||
|
#include <cudaTypedefs.h>
|
||||||
|
#include <torch/all.h>
|
||||||
|
|
||||||
|
#include <iostream>
|
||||||
|
|
||||||
|
constexpr uint64_t THREADS_PER_EXPERT = 512;
|
||||||
|
|
||||||
|
__global__ void compute_problem_sizes_w4a8(
|
||||||
|
const int32_t* __restrict__ topk_ids,
|
||||||
|
int32_t* problem_sizes1,
|
||||||
|
int32_t* problem_sizes2,
|
||||||
|
int32_t* atomic_buffer,
|
||||||
|
const int topk_length,
|
||||||
|
const int n,
|
||||||
|
const int k) {
|
||||||
|
int expert_id = blockIdx.x;
|
||||||
|
|
||||||
|
int occurrences = 0;
|
||||||
|
for (int i = threadIdx.x; i < topk_length; i += THREADS_PER_EXPERT) {
|
||||||
|
occurrences += (topk_ids[i] == expert_id);
|
||||||
|
}
|
||||||
|
atomicAdd(&atomic_buffer[expert_id], occurrences);
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
int final_occurrences = atomic_buffer[expert_id];
|
||||||
|
problem_sizes1[expert_id * 3] = 2 * n;
|
||||||
|
problem_sizes1[expert_id * 3 + 1] = final_occurrences;
|
||||||
|
problem_sizes1[expert_id * 3 + 2] = k;
|
||||||
|
problem_sizes2[expert_id * 3] = k;
|
||||||
|
problem_sizes2[expert_id * 3 + 1] = final_occurrences;
|
||||||
|
problem_sizes2[expert_id * 3 + 2] = n;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
__global__ void compute_expert_offsets_w4a8(
|
||||||
|
const int32_t* __restrict__ problem_sizes1,
|
||||||
|
int32_t* expert_offsets,
|
||||||
|
int32_t* atomic_buffer,
|
||||||
|
const int num_experts) {
|
||||||
|
int32_t tot_offset = 0;
|
||||||
|
expert_offsets[0] = 0;
|
||||||
|
for (int i = 0; i < num_experts; ++i) {
|
||||||
|
atomic_buffer[i] = tot_offset;
|
||||||
|
tot_offset += problem_sizes1[i * 3 + 1];
|
||||||
|
expert_offsets[i + 1] = tot_offset;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void get_cutlass_w4a8_moe_mm_data_caller(
|
||||||
|
const torch::Tensor& topk_ids,
|
||||||
|
torch::Tensor& expert_offsets,
|
||||||
|
torch::Tensor& problem_sizes1,
|
||||||
|
torch::Tensor& problem_sizes2,
|
||||||
|
torch::Tensor& input_permutation,
|
||||||
|
torch::Tensor& output_permutation,
|
||||||
|
const int64_t num_experts,
|
||||||
|
const int64_t n,
|
||||||
|
const int64_t k) {
|
||||||
|
auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index());
|
||||||
|
auto options_int32 = torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device());
|
||||||
|
torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32);
|
||||||
|
|
||||||
|
int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
|
||||||
|
compute_problem_sizes_w4a8<<<num_experts, num_threads, 0, stream>>>(
|
||||||
|
static_cast<const int32_t*>(topk_ids.data_ptr()),
|
||||||
|
static_cast<int32_t*>(problem_sizes1.data_ptr()),
|
||||||
|
static_cast<int32_t*>(problem_sizes2.data_ptr()),
|
||||||
|
static_cast<int32_t*>(atomic_buffer.data_ptr()),
|
||||||
|
topk_ids.numel(),
|
||||||
|
n,
|
||||||
|
k);
|
||||||
|
compute_expert_offsets_w4a8<<<1, 1, 0, stream>>>(
|
||||||
|
static_cast<const int32_t*>(problem_sizes1.data_ptr()),
|
||||||
|
static_cast<int32_t*>(expert_offsets.data_ptr()),
|
||||||
|
static_cast<int32_t*>(atomic_buffer.data_ptr()),
|
||||||
|
num_experts);
|
||||||
|
}
|
||||||
@@ -467,6 +467,35 @@ void transfer_kv_all_layer_mla_direct(
|
|||||||
int64_t page_size,
|
int64_t page_size,
|
||||||
int64_t num_layers);
|
int64_t num_layers);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* From csrc/moe/cutlass_moe/w4a8
|
||||||
|
*/
|
||||||
|
void get_cutlass_w4a8_moe_mm_data(
|
||||||
|
const torch::Tensor& topk_ids,
|
||||||
|
torch::Tensor& expert_offsets,
|
||||||
|
torch::Tensor& problem_sizes1,
|
||||||
|
torch::Tensor& problem_sizes2,
|
||||||
|
torch::Tensor& input_permutation,
|
||||||
|
torch::Tensor& output_permutation,
|
||||||
|
const int64_t num_experts,
|
||||||
|
const int64_t n,
|
||||||
|
const int64_t k);
|
||||||
|
|
||||||
|
void cutlass_w4a8_moe_mm(
|
||||||
|
torch::Tensor& d_tensors,
|
||||||
|
torch::Tensor const& a_tensors,
|
||||||
|
torch::Tensor const& b_tensors,
|
||||||
|
torch::Tensor const& a_scales,
|
||||||
|
torch::Tensor const& b_scales,
|
||||||
|
torch::Tensor const& expert_offsets,
|
||||||
|
torch::Tensor const& problem_sizes,
|
||||||
|
torch::Tensor const& a_strides,
|
||||||
|
torch::Tensor const& b_strides,
|
||||||
|
torch::Tensor const& d_strides,
|
||||||
|
torch::Tensor const& s_strides,
|
||||||
|
int64_t chunk_size,
|
||||||
|
int64_t topk);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* From FlashInfer
|
* From FlashInfer
|
||||||
*/
|
*/
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ from sgl_kernel.attention import (
|
|||||||
merge_state,
|
merge_state,
|
||||||
merge_state_v2,
|
merge_state_v2,
|
||||||
)
|
)
|
||||||
|
from sgl_kernel.cutlass_moe import cutlass_w4a8_moe_mm, get_cutlass_w4a8_moe_mm_data
|
||||||
from sgl_kernel.elementwise import (
|
from sgl_kernel.elementwise import (
|
||||||
apply_rope_with_cos_sin_cache_inplace,
|
apply_rope_with_cos_sin_cache_inplace,
|
||||||
fused_add_rmsnorm,
|
fused_add_rmsnorm,
|
||||||
|
|||||||
112
sgl-kernel/python/sgl_kernel/cutlass_moe.py
Normal file
112
sgl-kernel/python/sgl_kernel/cutlass_moe.py
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def get_cutlass_w4a8_moe_mm_data(
|
||||||
|
topk_ids: torch.Tensor,
|
||||||
|
expert_offsets: torch.Tensor,
|
||||||
|
problem_sizes1: torch.Tensor,
|
||||||
|
problem_sizes2: torch.Tensor,
|
||||||
|
input_permutation: torch.Tensor,
|
||||||
|
output_permutation: torch.Tensor,
|
||||||
|
num_experts: int,
|
||||||
|
n: int,
|
||||||
|
k: int,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Prepare data necessary to perform CUTLASS grouped matrix multiplications
|
||||||
|
used in CUTLASS-based fused MoE.
|
||||||
|
|
||||||
|
The function takes in topk_ids (token-expert mapping) and uses it to
|
||||||
|
compute:
|
||||||
|
- expert_offsets: Indices that mark at which token index each expert begins
|
||||||
|
its computation after the input is sorted with
|
||||||
|
input_permutation. The number of tokens computed with
|
||||||
|
expert E is expert_offsets[E + 1] - expert_offsets[E]
|
||||||
|
- problem_sizes1, problem_sizes2: MxNxK sizes of each expert's
|
||||||
|
multiplication in two grouped MMs used in
|
||||||
|
the fused MoE operation.
|
||||||
|
- input_permutation: Permutation that must be used to shuffle the input
|
||||||
|
before executing the MMs.
|
||||||
|
- output_permutation: Permutation that must be used to shuffle the output
|
||||||
|
after executing the MMs.
|
||||||
|
"""
|
||||||
|
torch.ops.sgl_kernel.get_cutlass_w4a8_moe_mm_data.default(
|
||||||
|
topk_ids,
|
||||||
|
expert_offsets,
|
||||||
|
problem_sizes1,
|
||||||
|
problem_sizes2,
|
||||||
|
input_permutation,
|
||||||
|
output_permutation,
|
||||||
|
num_experts,
|
||||||
|
n,
|
||||||
|
k,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def cutlass_w4a8_moe_mm(
|
||||||
|
d: torch.Tensor,
|
||||||
|
a: torch.Tensor,
|
||||||
|
b: torch.Tensor,
|
||||||
|
a_scales: torch.Tensor,
|
||||||
|
b_scales: torch.Tensor,
|
||||||
|
experts_offsets: torch.tensor,
|
||||||
|
problem_sizes: torch.tensor,
|
||||||
|
a_strides: torch.tensor,
|
||||||
|
b_strides: torch.tensor,
|
||||||
|
d_strides: torch.tensor,
|
||||||
|
s_strides: torch.tensor,
|
||||||
|
chunk_size: int = 128,
|
||||||
|
topk: int = 8,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Perform grouped matrix multiplication between int4 weights and fp8 activations.
|
||||||
|
|
||||||
|
This function executes multiple GEMM operations in parallel, which is useful for
|
||||||
|
scenarios like Mixture of Experts (MoE) where different inputs go through different
|
||||||
|
experts. The implementation leverages NVIDIA Hopper architecture features for
|
||||||
|
optimal performance with quantized weights.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
d: Output matrices of shape [total_m, total_n]
|
||||||
|
a: Activation matrices in FP8 (float_e4m3_t) format
|
||||||
|
Each tensor should be of shape [total_m, K] in row-major layout
|
||||||
|
b: Weight matrices in packed int4 format
|
||||||
|
Each tensor should be of shape [E, N, K//2] in column-major layout
|
||||||
|
where each byte contains two 4-bit integers
|
||||||
|
a_scales: Scale factors for the inputs
|
||||||
|
b_scales: Scale factors for the quantized weights
|
||||||
|
Each tensor should be of shape [E, K//512, N*8]
|
||||||
|
experts_offsets: Tensor containing expert offsets for determining group boundaries
|
||||||
|
problem_sizes: with shape [num_experts, 3] (M, N, K for each group) (int32)
|
||||||
|
a_strides: Strides information for A matrices
|
||||||
|
b_strides: Strides information for B matrices
|
||||||
|
d_strides: Strides information for D matrices
|
||||||
|
s_strides: Strides information for b_scales matrices
|
||||||
|
chunk_size: Number of elements each scale value applies to (K//512), default to 128
|
||||||
|
|
||||||
|
Requirements:
|
||||||
|
- All tensors must be on a CUDA device
|
||||||
|
- Requires an NVIDIA Hopper GPU (H100)
|
||||||
|
- A tensors must be in float8_e4m3fn format
|
||||||
|
- B tensors must contain packed int4 values (stored as int8)
|
||||||
|
|
||||||
|
Note:
|
||||||
|
The function computes: D = (A * (B * scales))
|
||||||
|
for each group of tensors in parallel
|
||||||
|
"""
|
||||||
|
|
||||||
|
torch.ops.sgl_kernel.cutlass_w4a8_moe_mm.default(
|
||||||
|
d,
|
||||||
|
a,
|
||||||
|
b,
|
||||||
|
a_scales,
|
||||||
|
b_scales,
|
||||||
|
experts_offsets,
|
||||||
|
problem_sizes,
|
||||||
|
a_strides,
|
||||||
|
b_strides,
|
||||||
|
d_strides,
|
||||||
|
s_strides,
|
||||||
|
chunk_size,
|
||||||
|
topk,
|
||||||
|
)
|
||||||
260
sgl-kernel/tests/test_cutlass_w4a8_moe_mm.py
Normal file
260
sgl-kernel/tests/test_cutlass_w4a8_moe_mm.py
Normal file
@@ -0,0 +1,260 @@
|
|||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
from sgl_kernel import cutlass_w4a8_moe_mm
|
||||||
|
|
||||||
|
|
||||||
|
def pack_int4_values_to_int8(int4_values_interleaved: torch.Tensor) -> torch.Tensor:
|
||||||
|
if int4_values_interleaved.shape[-1] % 2 != 0:
|
||||||
|
raise ValueError(
|
||||||
|
"the last dim size of int4_values_interleaved tensor must be even."
|
||||||
|
)
|
||||||
|
|
||||||
|
input_tensor_int8 = int4_values_interleaved.to(torch.int8)
|
||||||
|
|
||||||
|
low_nibbles = input_tensor_int8[..., 0::2]
|
||||||
|
high_nibbles = input_tensor_int8[..., 1::2]
|
||||||
|
|
||||||
|
packed_tensor = (high_nibbles << 4) | (low_nibbles & 0x0F)
|
||||||
|
|
||||||
|
return packed_tensor.to(torch.int8)
|
||||||
|
|
||||||
|
|
||||||
|
def pack_interleave(num_experts, ref_weight, ref_scale):
|
||||||
|
n, k = ref_weight.shape[1], ref_weight.shape[2]
|
||||||
|
|
||||||
|
weight = pack_int4_values_to_int8(ref_weight.cpu()).cuda()
|
||||||
|
w_q = weight.view((num_experts, n, k // 2)).view(torch.int8)
|
||||||
|
w_q = w_q.contiguous()
|
||||||
|
|
||||||
|
scale_interleaved = ref_scale.reshape(
|
||||||
|
ref_scale.shape[0], ref_scale.shape[1], (ref_scale.shape[2] // 4), 4
|
||||||
|
) # [E, N, K/4, 4]
|
||||||
|
scale_interleaved = scale_interleaved.permute(0, 2, 1, 3) # [E, K/4, N, 4]
|
||||||
|
scale_interleaved = scale_interleaved.reshape(
|
||||||
|
ref_scale.shape[0], ref_scale.shape[2] // 4, ref_scale.shape[1] * 4
|
||||||
|
) # [E, K/4, N*4]
|
||||||
|
w_scale = scale_interleaved.contiguous()
|
||||||
|
|
||||||
|
return w_q, w_scale
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16])
|
||||||
|
def test_int4_fp8_grouped_gemm_single_expert(batch_size):
|
||||||
|
# Test parameters
|
||||||
|
num_experts = 1
|
||||||
|
m = batch_size # batch size
|
||||||
|
k = 512 # input dimension
|
||||||
|
n = 1024 # output dimension
|
||||||
|
torch.manual_seed(0)
|
||||||
|
dtype = torch.bfloat16
|
||||||
|
device = "cuda"
|
||||||
|
debug = False
|
||||||
|
|
||||||
|
print(f"\nTesting with batch_size={batch_size}")
|
||||||
|
|
||||||
|
# Create input tensors with ones
|
||||||
|
if debug:
|
||||||
|
a = torch.ones(m, k, dtype=torch.bfloat16, device=device)
|
||||||
|
ref_w = torch.ones(num_experts, n, k, dtype=torch.int8, device=device)
|
||||||
|
a_scale = torch.ones(1, dtype=torch.float, device=device)
|
||||||
|
ref_w_scale = torch.ones(num_experts, n, k // 128, dtype=dtype, device=device)
|
||||||
|
else:
|
||||||
|
a = torch.randn(m, k, dtype=dtype, device=device)
|
||||||
|
ref_w = torch.randint(
|
||||||
|
-8, 8, (num_experts, n, k), dtype=torch.int8, device=device
|
||||||
|
)
|
||||||
|
affine_coeff = 0.005
|
||||||
|
a_scale = torch.randn(1, dtype=torch.float32).cuda() * 0.02
|
||||||
|
ref_w_scale = (
|
||||||
|
torch.randn(num_experts, n, k // 128, dtype=dtype, device=device)
|
||||||
|
* affine_coeff
|
||||||
|
)
|
||||||
|
|
||||||
|
w, w_scale = pack_interleave(num_experts, ref_w, ref_w_scale)
|
||||||
|
|
||||||
|
# Create expert offsets and problem sizes
|
||||||
|
expert_offsets = torch.tensor([0, m], dtype=torch.int32, device=device)
|
||||||
|
problem_sizes = torch.tensor([[n, m, k]], dtype=torch.int32, device=device)
|
||||||
|
|
||||||
|
a_strides = torch.full((num_experts, 3), k, device=device, dtype=torch.int64)
|
||||||
|
c_strides = torch.full((num_experts, 3), n, device=device, dtype=torch.int64)
|
||||||
|
b_strides = a_strides
|
||||||
|
s_strides = c_strides
|
||||||
|
|
||||||
|
# Quantize input
|
||||||
|
a_q = torch.clamp((a / a_scale), -448.0, 448.0).to(torch.float8_e4m3fn).to(device)
|
||||||
|
|
||||||
|
# Create output tensor
|
||||||
|
c = torch.empty((m, n), dtype=torch.float16, device=device)
|
||||||
|
cutlass_w4a8_moe_mm(
|
||||||
|
c,
|
||||||
|
a_q,
|
||||||
|
w,
|
||||||
|
a_scale,
|
||||||
|
w_scale,
|
||||||
|
expert_offsets[:-1],
|
||||||
|
problem_sizes,
|
||||||
|
a_strides,
|
||||||
|
b_strides,
|
||||||
|
c_strides,
|
||||||
|
s_strides,
|
||||||
|
128,
|
||||||
|
8,
|
||||||
|
)
|
||||||
|
c = c.to(dtype)
|
||||||
|
|
||||||
|
# Reference implementation
|
||||||
|
experts_selection_result = torch.full((m,), 0)
|
||||||
|
c_ref = ref_grouped_gemm(
|
||||||
|
c, a, a_scale, ref_w, ref_w_scale, num_experts, experts_selection_result
|
||||||
|
)
|
||||||
|
|
||||||
|
# Compare results
|
||||||
|
try:
|
||||||
|
torch.testing.assert_close(c, c_ref, rtol=1e-2, atol=0.1)
|
||||||
|
except AssertionError as e:
|
||||||
|
# torch.set_printoptions(threshold=10_000)
|
||||||
|
print(f" FAILURE: tensors are NOT close.")
|
||||||
|
print(f" Ref tensor: {c_ref.flatten()}")
|
||||||
|
print(f" Cutlass tensor: {c.flatten()}")
|
||||||
|
print(
|
||||||
|
f" Max absolute difference: {torch.max(torch.abs(c.to(c_ref.dtype) - c_ref))}"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f" Mean absolute difference: {torch.mean(torch.abs(c.to(c_ref.dtype) - c_ref))}"
|
||||||
|
)
|
||||||
|
print(f" AssertionError: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("batch_size", [2, 4, 8, 16])
|
||||||
|
@pytest.mark.parametrize("k", [512, 1024])
|
||||||
|
@pytest.mark.parametrize("n", [1024, 2048])
|
||||||
|
@pytest.mark.parametrize("num_experts", [2, 4, 6, 8])
|
||||||
|
def test_int4_fp8_grouped_gemm_multi_experts(batch_size, k, n, num_experts):
|
||||||
|
torch.manual_seed(0)
|
||||||
|
dtype = torch.bfloat16
|
||||||
|
device = "cuda"
|
||||||
|
debug = False
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"\nTesting with batch_size={batch_size}, k={k}, n={n}, num_experts={num_experts}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if debug:
|
||||||
|
a = torch.ones(batch_size, k, dtype=torch.bfloat16, device=device)
|
||||||
|
ref_w = torch.ones(num_experts, n, k, dtype=torch.int8, device=device)
|
||||||
|
a_scale = torch.ones(1, dtype=torch.float, device=device)
|
||||||
|
ref_w_scale = torch.ones(num_experts, n, k // 128, dtype=dtype, device=device)
|
||||||
|
else:
|
||||||
|
a = torch.randn(batch_size, k, dtype=dtype, device=device)
|
||||||
|
ref_w = torch.randint(
|
||||||
|
-8, 8, (num_experts, n, k), dtype=torch.int8, device=device
|
||||||
|
)
|
||||||
|
affine_coeff = 0.005
|
||||||
|
a_scale = torch.randn(1, dtype=torch.float32).cuda() * 0.02
|
||||||
|
ref_w_scale = (
|
||||||
|
torch.randn(num_experts, n, k // 128, dtype=dtype, device=device)
|
||||||
|
* affine_coeff
|
||||||
|
)
|
||||||
|
|
||||||
|
w, w_scale = pack_interleave(num_experts, ref_w, ref_w_scale)
|
||||||
|
|
||||||
|
# random select experts
|
||||||
|
experts_selection_result = torch.randint(
|
||||||
|
0, num_experts, (batch_size,), device=device
|
||||||
|
)
|
||||||
|
permutation = torch.argsort(experts_selection_result)
|
||||||
|
expert_token_counts = torch.bincount(
|
||||||
|
experts_selection_result, minlength=num_experts
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create problem sizes and offsets for active experts
|
||||||
|
problem_sizes = []
|
||||||
|
for i in range(num_experts):
|
||||||
|
problem_sizes.append([n, expert_token_counts[i].item(), k])
|
||||||
|
problem_sizes = torch.tensor(problem_sizes, dtype=torch.int32, device=device)
|
||||||
|
|
||||||
|
expert_offsets = []
|
||||||
|
offset = 0
|
||||||
|
for i in range(num_experts):
|
||||||
|
expert_offsets.append(offset)
|
||||||
|
offset += problem_sizes[i][1].item()
|
||||||
|
expert_offsets = torch.tensor(expert_offsets, dtype=torch.int32, device=device)
|
||||||
|
|
||||||
|
# Permute input and quantize
|
||||||
|
a_perm = a[permutation]
|
||||||
|
a_q_perm = (
|
||||||
|
torch.clamp((a_perm / a_scale), -448.0, 448.0)
|
||||||
|
.to(torch.float8_e4m3fn)
|
||||||
|
.to(device)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create stride tensors
|
||||||
|
a_strides = torch.full((num_experts, 3), k, device=device, dtype=torch.int64)
|
||||||
|
c_strides = torch.full((num_experts, 3), n, device=device, dtype=torch.int64)
|
||||||
|
b_strides = a_strides
|
||||||
|
s_strides = c_strides
|
||||||
|
|
||||||
|
c_perm = torch.empty((batch_size, n), dtype=torch.float16, device=device)
|
||||||
|
cutlass_w4a8_moe_mm(
|
||||||
|
c_perm,
|
||||||
|
a_q_perm,
|
||||||
|
w,
|
||||||
|
a_scale,
|
||||||
|
w_scale,
|
||||||
|
expert_offsets,
|
||||||
|
problem_sizes,
|
||||||
|
a_strides,
|
||||||
|
b_strides,
|
||||||
|
c_strides,
|
||||||
|
s_strides,
|
||||||
|
128,
|
||||||
|
8,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Un-permute the result
|
||||||
|
c = torch.empty_like(c_perm)
|
||||||
|
c[permutation] = c_perm
|
||||||
|
c = c.to(dtype)
|
||||||
|
|
||||||
|
c_ref = ref_grouped_gemm(
|
||||||
|
c, a, a_scale, ref_w, ref_w_scale, num_experts, experts_selection_result
|
||||||
|
)
|
||||||
|
|
||||||
|
# Compare results
|
||||||
|
try:
|
||||||
|
torch.testing.assert_close(c, c_ref, rtol=1e-2, atol=0.1)
|
||||||
|
except AssertionError as e:
|
||||||
|
print(f" FAILURE: tensors are NOT close.")
|
||||||
|
print(
|
||||||
|
f" Max absolute difference: {torch.max(torch.abs(c.to(c_ref.dtype) - c_ref))}"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f" Mean absolute difference: {torch.mean(torch.abs(c.to(c_ref.dtype) - c_ref))}"
|
||||||
|
)
|
||||||
|
print(f" AssertionError: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
def ref_grouped_gemm(c, a, a_scale, w, w_scale, num_experts, experts_selection_result):
|
||||||
|
dtype = torch.bfloat16
|
||||||
|
c_ref = torch.zeros_like(c)
|
||||||
|
a_q = torch.clamp((a / a_scale), -448.0, 448.0).to(torch.float8_e4m3fn)
|
||||||
|
for i in range(num_experts):
|
||||||
|
token_idx = torch.where(experts_selection_result == i)[0]
|
||||||
|
if len(token_idx) == 0:
|
||||||
|
continue
|
||||||
|
a = a_q[token_idx]
|
||||||
|
|
||||||
|
ref_w_scale_repeat = w_scale[i].repeat_interleave(128, dim=1).to(float)
|
||||||
|
ref_w = (w[i].to(float) * ref_w_scale_repeat).to(dtype)
|
||||||
|
c = torch.matmul(a.to(dtype), ref_w.t().to(dtype)) * a_scale
|
||||||
|
c = c.to(dtype)
|
||||||
|
c_ref[token_idx] = c.to(dtype)
|
||||||
|
|
||||||
|
return c_ref
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
pytest.main([__file__])
|
||||||
Reference in New Issue
Block a user