[1/n]: add CUTLASS W4A8 MoE kernel for Hopper architecture (#7772)
Signed-off-by: yangsijia.614 <yangsijia.614@bytedance.com>
Co-authored-by: yicwang <yichen.wang@bytedance.com>
@@ -249,6 +249,9 @@ set(SOURCES
    "csrc/speculative/speculative_sampling.cu"
    "csrc/grammar/apply_token_bitmask_inplace_cuda.cu"
    "csrc/kvcacheio/transfer.cu"
    "csrc/moe/cutlass_moe/w4a8/scaled_mm_entry.cu"
    "csrc/moe/cutlass_moe/w4a8/w4a8_moe_data.cu"
    "csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cu"
    "csrc/common_extension.cc"
    "csrc/moe/marlin_moe_wna16/ops.cu"
    "csrc/moe/marlin_moe_wna16/gptq_marlin_repack.cu"
@@ -277,6 +277,25 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
      "int num_layers) -> ()");
  m.impl("transfer_kv_all_layer_mla_direct", torch::kCUDA, &transfer_kv_all_layer_mla_direct);

  /*
   * From csrc/moe/cutlass_moe/w4a8
   */
  m.def(
      "get_cutlass_w4a8_moe_mm_data(Tensor topk_ids, Tensor! expert_offsets, "
      " Tensor! problem_sizes1, Tensor! problem_sizes2, "
      " Tensor! input_permutation, "
      " Tensor! output_permutation, int num_experts, "
      " int n, int k) -> ()");
  m.impl("get_cutlass_w4a8_moe_mm_data", torch::kCUDA, &get_cutlass_w4a8_moe_mm_data);

  m.def(
      "cutlass_w4a8_moe_mm(Tensor! d, Tensor a, Tensor b, "
      " Tensor a_scales, Tensor b_scales, Tensor expert_offsets, "
      " Tensor problem_sizes, Tensor a_strides, "
      " Tensor b_strides, Tensor d_strides, Tensor s_strides,"
      " int chunk_size, int topk) -> ()");
  m.impl("cutlass_w4a8_moe_mm", torch::kCUDA, &cutlass_w4a8_moe_mm);

  /*
   * From FlashInfer
   */
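Editor's note: the two schemas above form the W4A8 MoE pipeline exposed to the frontend: get_cutlass_w4a8_moe_mm_data builds per-expert offsets, problem sizes, and token permutations from topk_ids, and cutlass_w4a8_moe_mm then runs the grouped GEMM. The sketch below shows how the pair might be chained from host C++; the dtypes follow the checks in scaled_mm_entry.cu and w4a8_get_group_starts.cuh later in this diff, while the metadata tensor shapes and all variable names are illustrative assumptions, not taken from this commit.

// Hypothetical driver; metadata shapes are assumptions, not from this diff.
#include <torch/all.h>

void run_w4a8_moe_sketch(
    torch::Tensor a,         // [num_tokens * topk, k], torch::kFloat8_e4m3fn
    torch::Tensor b,         // per-expert int4 weights packed into torch::kInt8
    torch::Tensor a_scales,  // torch::kFloat32
    torch::Tensor b_scales,  // torch::kBFloat16 group-wise scales
    torch::Tensor topk_ids,
    int64_t num_experts, int64_t n, int64_t k, int64_t topk) {
  auto opts_i32 = torch::dtype(torch::kInt32).device(a.device());
  // Assumed metadata shapes, mirroring other cutlass_moe data-prep kernels.
  auto expert_offsets = torch::empty({num_experts + 1}, opts_i32);
  auto problem_sizes1 = torch::empty({num_experts, 3}, opts_i32);
  auto problem_sizes2 = torch::empty({num_experts, 3}, opts_i32);
  auto input_permutation = torch::empty({topk_ids.numel()}, opts_i32);
  auto output_permutation = torch::empty({topk_ids.numel()}, opts_i32);

  get_cutlass_w4a8_moe_mm_data(
      topk_ids, expert_offsets, problem_sizes1, problem_sizes2,
      input_permutation, output_permutation, num_experts, n, k);

  auto d = torch::empty({a.size(0), n}, torch::dtype(torch::kBFloat16).device(a.device()));
  // a_strides/b_strides/d_strides/s_strides hold per-expert leading dimensions (construction omitted);
  // with those prepared, the grouped GEMM itself would be invoked as:
  // cutlass_w4a8_moe_mm(d, a, b, a_scales, b_scales, expert_offsets, problem_sizes1,
  //                     a_strides, b_strides, d_strides, s_strides, /*chunk_size=*/..., topk);
}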
@@ -0,0 +1,482 @@
/*
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include "cute/arch/copy_sm90.hpp"
#include "cute/numeric/arithmetic_tuple.hpp"
#include "cute/util/type_traits.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_conversion.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass::gemm::collective::detail {

template <class Collective>
struct MixedGroupedGemmInputUtils {
 private:
  using KernelSchedule = typename Collective::KernelSchedule;
  using ConversionMode = typename Collective::ConversionMode;
  using SmemLayoutA = typename Collective::SmemLayoutA;
  using SmemLayoutB = typename Collective::SmemLayoutB;
  using SmemLayoutScale = typename Collective::SmemLayoutScale;
  using SwappedElementA = typename Collective::SwappedElementA;
  using SwappedElementB = typename Collective::SwappedElementB;
  using RealSwappedElementA = typename Collective::RealSwappedElementA;
  using RealSwappedElementB = typename Collective::RealSwappedElementB;
  using ElementScale = typename Collective::ElementScale;
  using ElementZero = typename Collective::ElementZero;
  using SmemCopyAtomScale = typename Collective::SmemCopyAtomScale;
  static constexpr auto KernelConversionMode = Collective::KernelConversionMode;
  static constexpr auto ModeHasScales = Collective::ModeHasScales;
  static constexpr auto UseScaleLookupTable = Collective::UseScaleLookupTable;

 public:
  static constexpr auto elements_per_smem_scale() {
    if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
      return 0;
    } else if constexpr (ModeHasScales) {
      return cute::cosize_v<SmemLayoutScale>;
    } else {
      static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Type not handled in scale smem allocation.");
    }
  }

  static constexpr auto elements_per_smem_zero() {
    if constexpr (
        KernelConversionMode == ConversionMode::DirectConvert ||
        KernelConversionMode == ConversionMode::ConvertAndScale) {
      return 0;
    } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
      return cute::cosize_v<SmemLayoutScale>;
    } else {
      static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Type not handled in scale smem allocation.");
    }
  }

  // These methods use some of the public members of the class. For that reason, we define them after the public section.
  static constexpr uint32_t compute_tma_transaction_bytes_mk() {
    return cutlass::bits_to_bytes(
        size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast<uint32_t>(cute::sizeof_bits_v<SwappedElementA>));
  }

  static constexpr uint32_t compute_tma_transaction_bytes_nk() {
    return cutlass::bits_to_bytes(
        size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast<uint32_t>(cute::sizeof_bits_v<SwappedElementB>));
  }

  static constexpr uint32_t compute_tma_transaction_bytes_extra() {
    if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
      return 0;
    } else if constexpr (ModeHasScales) {
      constexpr uint32_t scale_tx_bytes = cutlass::bits_to_bytes(
          size<0>(SmemLayoutScale{}) * size<1>(SmemLayoutScale{}) *
          static_cast<uint32_t>(cute::sizeof_bits_v<ElementScale>));
      static_assert(scale_tx_bytes % 128 == 0, "Each scale stage must be 128B aligned.");  // required by TMA
      if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
        return scale_tx_bytes;
      } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
        // Scale and zero share smem layout
        constexpr uint32_t zero_tx_bytes = cutlass::bits_to_bytes(
            size<0>(SmemLayoutScale{}) * size<1>(SmemLayoutScale{}) *
            static_cast<uint32_t>(cute::sizeof_bits_v<ElementZero>));
        static_assert(zero_tx_bytes % 128 == 0, "Each zero stage must be 128B aligned.");  // required by TMA
        return scale_tx_bytes + zero_tx_bytes;
      } else {
        static_assert(
            cutlass::detail::dependent_false<KernelSchedule>, "Type not handled in tma transaction bytes computation.");
      }
    } else {
      static_assert(
          cutlass::detail::dependent_false<KernelSchedule>, "Type not handled in tma transaction bytes computation.");
    }
  }

  /// Utilities to copy A and extra inputs from smem to RF
  template <class SmemTiledCopyA, class TensorASmemView, class TensorACopyView, class... Ts, class... Us>
  CUTLASS_DEVICE static void copy_tensors_MK(
      SmemTiledCopyA const& smem_tiled_copy_A,
      TensorASmemView const& tCsA,
      TensorACopyView& tCrA_copy_view,
      cute::tuple<Ts...> const& partitioned_mma_extra_info,
      cute::tuple<Us...> const& tiled_copy_and_views,
      int k_block,
      int read_stage) {
    copy(smem_tiled_copy_A, tCsA(_, _, k_block, read_stage), tCrA_copy_view(_, _, k_block));

    if (k_block == 0) {
      // We are starting a new k-tile so copy the scale
      if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
        // nothing to do
      } else if constexpr (ModeHasScales) {
        auto smem_tiled_copy_S = cute::get<0>(tiled_copy_and_views);
        auto tCrS_copy_view = cute::get<1>(tiled_copy_and_views);
        auto tCsS = cute::get<0>(partitioned_mma_extra_info);
        copy(smem_tiled_copy_S, tCsS(_, _, k_block, read_stage), tCrS_copy_view(_, _, k_block));
        if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
          // Nothing extra to do
        } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
          auto tCsZ = cute::get<2>(partitioned_mma_extra_info);
          auto tCrZ_copy_view = cute::get<2>(tiled_copy_and_views);
          copy(smem_tiled_copy_S, tCsZ(_, _, k_block, read_stage), tCrZ_copy_view(_, _, k_block));
        } else {
          static_assert(
              cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in A -> RF path.");
        }
      } else {
        static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in A -> RF path.");
      }
    }
  }

  // The core converter uses a lookup table to convert i4 -> 8-bit values.
  template <
      class EngineIn,
      class LayoutIn,
      class EngineOut,
      class LayoutOut,
      class EngineScale,
      class LayoutScale>
  CUTLASS_DEVICE static void lookup_table_convert(  // Accept mutable temporaries
      Tensor<EngineIn, LayoutIn> const& src,
      Tensor<EngineOut, LayoutOut>&& dst,
      Tensor<EngineScale, LayoutScale> const& scales_neg,
      Tensor<EngineScale, LayoutScale> const& scales_pos) {
    lookup_table_convert(src, dst, scales_neg, scales_pos);
  }
  template <class EngineIn, class LayoutIn, class EngineOut, class LayoutOut, class EngineScale, class LayoutScale>
  CUTLASS_DEVICE static void lookup_table_convert(
      Tensor<EngineIn, LayoutIn> const& src,
      Tensor<EngineOut, LayoutOut>& dst,
      Tensor<EngineScale, LayoutScale> const& scales_neg,
      Tensor<EngineScale, LayoutScale> const& scales_pos) {
    constexpr int N = cute::cosize(LayoutIn{});
    static_assert(N == 4 || N == 8);
    static_assert(cosize(LayoutScale{}) <= N / 4, "at least 4 consecutive weights must share the same scale.");
    using SrcArray = cutlass::Array<cutlass::int4b_t, 8>;
    using DstArray = cutlass::Array<RealSwappedElementB, 8>;
    using RegArray = cutlass::AlignedArray<uint32_t, N / 4, sizeof(DstArray)>;

    // View the input as reg
    auto&& src_reg = cute::recast<uint32_t>(src)(0);
    auto&& r = cute::recast<RegArray>(dst)(0);

    // Determines whether to select from the signed or unsigned candidates
    static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
    uint32_t sign;  // ((reg & 0x88888888) | 0x64206420) >> 1
    asm volatile(
        "{\n"
        " lop3.b32 %0, %1, %2, %3, %4;\n"
        "}\n"
        : "=r"(sign)
        : "r"(src_reg), "n"(0x88888888), "n"(0x64206420), "n"(immLut));
    sign = sign >> 1;

    // Ignore sign bit when indexing into LUT
    uint32_t lut_idx = src_reg & 0x77777777;
    Tensor scales_neg_ = cute::filter(scales_neg);
    Tensor scales_pos_ = cute::filter(scales_pos);
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 4; ++i, lut_idx >>= 16, sign >>= 16) {
      auto&& scale_neg_ = reinterpret_cast<cutlass::Array<uint32_t, 2> const&>(scales_neg_(i));
      auto&& scale_pos_ = reinterpret_cast<cutlass::Array<uint32_t, 2> const&>(scales_pos_(i));
      asm volatile(
          "{\n"
          " .reg .b32 pos, neg ;\n"
          " prmt .b32 neg, %3, %4, %1 ;\n"
          " prmt .b32 pos, %5, %6, %1 ;\n"
          " prmt .b32 %0, pos, neg, %2 ;\n"
          "}\n"
          : "=r"(r[i])
          : "r"(lut_idx), "r"(sign), "r"(scale_neg_[0]), "r"(scale_neg_[1]), "r"(scale_pos_[0]), "r"(scale_pos_[1]));
    }
  }

  /// Utilities to dequantize A.
  template <class Layout>
  CUTLASS_DEVICE static void static_check_scale(Layout const& tensor) {
    static_assert(
        shape<0>(Layout{}) >= 4 && stride<0>(Layout{}) == 0,
        "At least 4 adjacent weights in a thread must share the same scale.");
  }
  template <class Engine, class Layout>
  CUTLASS_DEVICE static void static_check_scale(Tensor<Engine, Layout> const& tensor) {
    static_check_scale(flatten(Layout{}));
  }

  template <class EngineIn, class EngineOut, class LayoutIn, class LayoutOut, class... Ts>
  CUTLASS_DEVICE static void dequantize_A_kblock(
      Tensor<EngineIn, LayoutIn> const& tCrA_load,
      Tensor<EngineOut, LayoutOut>& tCrA_mma,
      cute::tuple<Ts...>& partitioned_extra_info,
      int const k_block) {
    static_assert(is_rmem<EngineIn>::value, "Input tensor for A conversion must come from registers");
    static_assert(is_rmem<EngineOut>::value, "Output tensor for A conversion must come from registers");
    static_assert(cosize_v<LayoutIn> == cosize_v<LayoutOut>);
    static_assert(size_v<LayoutIn> == cosize_v<LayoutIn>);
    static_assert(size_v<LayoutOut> == cosize_v<LayoutOut>);
    using SrcType = typename EngineIn::value_type;
    using DstType = typename EngineOut::value_type;

    Tensor src = tCrA_load(_, _, k_block);
    Tensor dst = tCrA_mma(_, _, k_block);

    CUTE_STATIC_ASSERT_V(
        size(src(_, 0)) == cosize(src(_, 0).layout()), "The first mode of tensor src must be contiguous in memory");
    // try to make the size of the first mode equal to 32bit
    int constexpr NumValPerSrcReg = cute::min(decltype(size(src(_, 0)))::value, ceil_div(32, sizeof_bits_v<SrcType>));
    Tensor src_vm = cute::group_modes<1, -1>(cute::zipped_divide(src, Int<NumValPerSrcReg>{}));
    Tensor dst_vm = cute::group_modes<1, -1>(cute::zipped_divide(dst, Int<NumValPerSrcReg>{}));

    if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < size<1>(dst_vm); ++i) {
        LayoutAwareConvert(src_vm(_, i), dst_vm(_, i));
      }
    } else if constexpr (UseScaleLookupTable) {
      constexpr int num_elements = decltype(size(src))::value;
      static_assert(
          is_same_v<RealSwappedElementA, cutlass::int4b_t>,
          "Lookup table only supports int4 being the quant type now.");
      static_assert(sizeof_bits_v<ElementScale> == 64, "Lookup table only supports 8 8bit scale values now.");
      static_assert(
          num_elements % 4 == 0 && num_elements >= 4, "Lookup table requires a vector size of 4x when converting.");

      Tensor tCrS_neg = cute::get<1>(partitioned_extra_info);
      auto&& tCrS_pos = cute::get<2>(partitioned_extra_info);  // modification to its value is needed
      Tensor scales_neg = tCrS_neg(_, _, k_block);
      Tensor scales_pos = tCrS_pos(_, _, k_block);
      CUTE_STATIC_ASSERT_V(cute::size(src) == cute::size(scales_neg));

      static_check_scale(scales_neg);
      static_check_scale(scales_pos);
      Tensor scales_neg_vm = cute::group_modes<1, -1>(cute::zipped_divide(scales_neg, Int<NumValPerSrcReg>{}));
      Tensor scales_pos_vm = cute::group_modes<1, -1>(cute::zipped_divide(scales_pos, Int<NumValPerSrcReg>{}));

      if (k_block == 0) {
        Tensor scales_neg_vm_ = filter(scales_neg_vm);
        Tensor scales_pos_vm_ = filter(scales_pos_vm);
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < size(scales_neg_vm_.layout()); ++i) {
          auto&& scale_neg_ = reinterpret_cast<cutlass::Array<uint32_t, 2> const&>(scales_neg_vm_(i));
          auto&& scale_pos_ = reinterpret_cast<cutlass::Array<uint32_t, 2>&>(scales_pos_vm_(i));
          constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa;
          asm volatile(
              "{\n"
              " lop3 .b32 %0, %2, %4, %5, %6;\n"
              " xor .b32 %1, %3, %5; \n"
              "}\n"
              : "=r"(scale_pos_[0]), "=r"(scale_pos_[1])
              : "r"(scale_neg_[0]), "r"(scale_neg_[1]), "n"(0xFFFFFF00), "n"(0x80808080), "n"(immLut));
        }
      }
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < size<1>(dst_vm); ++i) {
        lookup_table_convert(src_vm(_, i), dst_vm(_, i), scales_neg_vm(_, i), scales_pos_vm(_, i));
      }
    } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
      Tensor scales = cute::get<1>(partitioned_extra_info)(_, _, k_block);
      CUTE_STATIC_ASSERT_V(size(src) == size(scales));
      Tensor scales_vm = cute::group_modes<1, -1>(cute::zipped_divide(scales, Int<NumValPerSrcReg>{}));

      if constexpr (is_same_v<DstType, ElementScale>) {
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < size<1>(dst_vm); ++i) {
          LayoutAwareConvert(src_vm(_, i), dst_vm(_, i));
          CUTLASS_PRAGMA_UNROLL
          for (int j = 0; j < size<0>(dst_vm); ++j) {
            dst_vm(j, i) *= scales_vm(j, i);
          }
        }
      } else {
        auto stage = make_tensor_like<ElementScale>(src_vm(_, 0));
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < size<1>(dst_vm); ++i) {
          LayoutAwareConvert(src_vm(_, i), stage);
          CUTLASS_PRAGMA_UNROLL
          for (int j = 0; j < size<0>(dst_vm); ++j) {
            stage(j) *= scales_vm(j, i);
          }
          LayoutAwareConvert(stage, dst_vm(_, i));
        }
      }
    } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
      static_assert(is_same_v<ElementScale, ElementZero>, "ElementScale and ElementZero must be the same.");
      Tensor scales = cute::get<1>(partitioned_extra_info)(_, _, k_block);
      Tensor zeros = cute::get<3>(partitioned_extra_info)(_, _, k_block);
      CUTE_STATIC_ASSERT_V(size(src) == size(scales));
      CUTE_STATIC_ASSERT_V(size(src) == size(zeros));
      Tensor scales_vm = cute::group_modes<1, -1>(cute::zipped_divide(scales, Int<NumValPerSrcReg>{}));
      Tensor zeros_vm = cute::group_modes<1, -1>(cute::zipped_divide(zeros, Int<NumValPerSrcReg>{}));

      if constexpr (is_same_v<DstType, ElementScale>) {
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < size<1>(dst_vm); ++i) {
          LayoutAwareConvert(src_vm(_, i), dst_vm(_, i));
          CUTLASS_PRAGMA_UNROLL
          for (int j = 0; j < size<0>(dst_vm); ++j) {
            dst_vm(j, i) = dst_vm(j, i) * scales_vm(j, i) + zeros_vm(j, i);
          }
        }
      } else {
        auto stage = make_tensor_like<ElementScale>(src_vm(_, 0));
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < size<1>(dst_vm); ++i) {
          LayoutAwareConvert(src_vm(_, i), stage);
          CUTLASS_PRAGMA_UNROLL
          for (int j = 0; j < size<0>(dst_vm); ++j) {
            stage(j) = stage(j) * scales_vm(j, i) + zeros_vm(j, i);
          }
          LayoutAwareConvert(stage, dst_vm(_, i));
        }
      }
    } else {
      static_assert(cutlass::detail::dependent_false<KernelSchedule>, "No A data is loaded.");
    }
  }

  template <class EngineIn, class EngineOut, class LayoutIn, class LayoutOut, class... Ts>
  CUTLASS_DEVICE static void convert_A_kblock(
      Tensor<EngineIn, LayoutIn> const& tCrA_load, Tensor<EngineOut, LayoutOut>& tCrA_mma, int const k_block) {
    static_assert(is_rmem<EngineIn>::value, "Input tensor for A conversion must come from registers");
    static_assert(is_rmem<EngineOut>::value, "Output tensor for A conversion must come from registers");
    static_assert(cosize_v<LayoutIn> == cosize_v<LayoutOut>);
    static_assert(size_v<LayoutIn> == cosize_v<LayoutIn>);
    static_assert(size_v<LayoutOut> == cosize_v<LayoutOut>);
    using SrcType = typename EngineIn::value_type;

    Tensor src = tCrA_load(_, _, k_block);
    Tensor dst = tCrA_mma(_, _, k_block);

    CUTE_STATIC_ASSERT_V(
        size(src(_, 0)) == cosize(src(_, 0).layout()), "The first mode of tensor src must be contiguous in memory");
    // try to make the size of the first mode equal to 32bit
    int constexpr NumValPerSrcReg = cute::min(decltype(size(src(_, 0)))::value, ceil_div(32, sizeof_bits_v<SrcType>));
    Tensor src_vm = cute::group_modes<1, -1>(cute::zipped_divide(src, Int<NumValPerSrcReg>{}));
    Tensor dst_vm = cute::group_modes<1, -1>(cute::zipped_divide(dst, Int<NumValPerSrcReg>{}));

    // KernelConversionMode == ConversionMode::DirectConvert
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < size<1>(dst_vm); ++i) {
      LayoutAwareConvert(src_vm(_, i), dst_vm(_, i));
    }
  }

  /// Utilities for any additional inputs inside of the TMA load
  template <class Params, class TensorStorage, class... Ts>
  CUTLASS_DEVICE static auto partition_extra_tma_inputs(
      Params const& mainloop_params,
      cute::tuple<Ts...> const& load_inputs,
      TensorStorage& shared_tensors,
      uint2 const& cluster_local_block_id,
      int const m_coord,
      int const l_coord) {
    if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
      return cute::make_tuple();
    } else if constexpr (ModeHasScales) {
      Tensor sS =
          make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{});  // (BLK_M,BLK_K,PIPE)
      Tensor gS_mkl = get<2>(load_inputs);
      auto block_tma_s = mainloop_params.tma_load_scale.get_slice(cluster_local_block_id.y);
      Tensor gS = gS_mkl(_, _, m_coord, _, l_coord);  // (BLK_M,BLK_K,k)

      Tensor tSgS = block_tma_s.partition_S(gS);  // (TMA,TMA_M,TMA_K,k)
      Tensor tSsS = block_tma_s.partition_D(sS);  // (TMA,TMA_M,TMA_K,PIPE)
      if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
        return cute::make_tuple(tSgS, tSsS);
      } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
        Tensor sZ =
            make_tensor(make_smem_ptr(shared_tensors.smem_zero.begin()), SmemLayoutScale{});  // (BLK_M,BLK_K,PIPE)
        Tensor gZ_mkl = get<3>(load_inputs);
        auto block_tma_z = mainloop_params.tma_load_zero.get_slice(cluster_local_block_id.y);
        Tensor gZ = gZ_mkl(_, _, m_coord, _, l_coord);  // (BLK_M,BLK_K,k)

        Tensor tZgZ = block_tma_z.partition_S(gZ);  // (TMA,TMA_M,TMA_K,k)
        Tensor tZsZ = block_tma_z.partition_D(sZ);  // (TMA,TMA_M,TMA_K,PIPE)
        return cute::make_tuple(tSgS, tSsS, tZgZ, tZsZ);
      } else {
        static_assert(
            cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled for input partitioning.");
      }
    } else {
      static_assert(
          cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled for input partitioning.");
    }
  }

  /// Utilities for partitioning extra inputs for loading from smem in the mainloop.
  template <class ThreadMma, class TensorStorage>
  CUTLASS_DEVICE static auto
  partition_extra_mma_info(ThreadMma const& mma_thread_slice, TensorStorage& shared_tensors) {
    if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
      // nothing to do
      return cute::make_tuple();
    } else if constexpr (UseScaleLookupTable) {
      Tensor sS =
          make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{});  // (BLK_M,BLK_SCALE_K,PIPE)
      Tensor tCsS = mma_thread_slice.partition_A(sS);
      Tensor tCrS = make_tensor<ElementScale>(mma_thread_slice.partition_fragment_A(sS(_, _, Int<0>{})).layout());

      return cute::make_tuple(tCsS, tCrS);
    } else if constexpr (ModeHasScales) {
      Tensor sS =
          make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{});  // (BLK_M,BLK_SCALE_K,PIPE)
      Tensor tCsS = mma_thread_slice.partition_A(sS);
      Tensor tCrS = make_tensor<ElementScale>(mma_thread_slice.partition_fragment_A(sS(_, _, Int<0>{})).layout());

      if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
        return cute::make_tuple(tCsS, tCrS);
      } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
        Tensor sZ = make_tensor(
            make_smem_ptr(shared_tensors.smem_zero.begin()), SmemLayoutScale{});  // (BLK_M,BLK_SCALE_K,PIPE)
        Tensor tCsZ = mma_thread_slice.partition_A(sZ);
        Tensor tCrZ = make_tensor<ElementZero>(mma_thread_slice.partition_fragment_A(sZ(_, _, Int<0>{})).layout());
        return cute::make_tuple(tCsS, tCrS, tCsZ, tCrZ);
      } else {
        static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in A -> RF path.");
      }
    } else {
      static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in A -> RF path.");
    }
  }

  /// Returns the tiled copy and copy views for the extra inputs.
  template <class TiledMma, class... Ts>
  CUTLASS_DEVICE static auto retile_extra_mma_info(
      TiledMma const& tiled_mma, cute::tuple<Ts...>& partitioned_extra_info, int const warp_group_thread_idx) {
    if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
      // nothing to do
      return cute::make_tuple();
    } else if constexpr (ModeHasScales) {
      auto smem_tiled_copy_S = make_tiled_copy_A(SmemCopyAtomScale{}, tiled_mma);
      auto smem_thr_copy_S = smem_tiled_copy_S.get_thread_slice(warp_group_thread_idx);
      Tensor tCrS_copy_view = smem_thr_copy_S.retile_D(cute::get<1>(partitioned_extra_info));  // (CPY,CPY_M,CPY_K)

      if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
        return cute::make_tuple(smem_tiled_copy_S, tCrS_copy_view);
      } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
        Tensor tCrZ_copy_view = smem_thr_copy_S.retile_D(cute::get<3>(partitioned_extra_info));  // (CPY,CPY_M,CPY_K)
        return cute::make_tuple(smem_tiled_copy_S, tCrS_copy_view, tCrZ_copy_view);
      } else {
        static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in A -> RF path.");
      }
    } else {
      static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in A -> RF path.");
    }
  }
};

}  // namespace cutlass::gemm::collective::detail
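Editor's note: in the lookup-table path above, each 64-bit ElementScale entry appears to hold the eight possible scaled byte values for one quantization group (see the static_assert "Lookup table only supports 8 8bit scale values now"), so int4 dequantization plus scaling collapses into a single byte gather: the three prmt instructions fetch one output byte per 4-bit weight from either the "neg" or the "pos" table, selected by the weight's sign bit. A rough scalar sketch of that gather is below; it is an interpretation of the PTX for illustration, and the exact packing of the two tables (hence which one the sign bit selects) is an assumption, not stated in the diff.

// Hedged scalar sketch of the per-nibble byte gather performed by the prmt sequence above.
// neg_tbl / pos_tbl are the 8 bytes of one scales_neg / scales_pos entry.
inline unsigned char lut_convert_one_nibble(unsigned char nibble,
                                            const unsigned char neg_tbl[8],
                                            const unsigned char pos_tbl[8]) {
  unsigned char idx = nibble & 0x7;     // low 3 bits index an 8-entry byte table (lut_idx above)
  bool sign_set = (nibble & 0x8) != 0;  // int4 sign bit, isolated by the lop3/shift above
  return sign_set ? neg_tbl[idx] : pos_tbl[idx];  // assumed mapping of sign bit to the *_neg table
}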
@@ -0,0 +1,278 @@
/*
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include "cute/arch/cluster_sm90.hpp"
#include "cute/tensor.hpp"
#include "cutlass/gemm/collective/builders/sm90_common.inl"
#include "cutlass/gemm/collective/collective_builder_decl.hpp"
#include "cutlass/gemm/collective/collective_mma_decl.hpp"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/pipeline/sm90_pipeline.hpp"

// SM90 Collective Builders should be used only starting CUDA 12.0
#if (__CUDACC_VER_MAJOR__ >= 12)
#define CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
#endif

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass::gemm::collective {

/////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA_TMA_WS_RS
template <
    class ElementA_,
    class GmemLayoutATag_,
    int AlignmentA,
    class ElementB_,
    class GmemLayoutBTag_,
    int AlignmentB,
    class ElementAccumulator,
    class TileShape_MNK,
    class ClusterShape_MNK,
    class StageCountType,
    class KernelScheduleType>
struct CollectiveBuilderMixedInput<
    arch::Sm90,
    arch::OpClassTensorOp,
    ElementA_,
    GmemLayoutATag_,
    AlignmentA,
    ElementB_,
    GmemLayoutBTag_,
    AlignmentB,
    ElementAccumulator,
    TileShape_MNK,
    ClusterShape_MNK,
    StageCountType,
    KernelScheduleType,
    cute::enable_if_t<
        (cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecialized> ||
         cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedPingpong> ||
         cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperative> ||
         cute::is_same_v<KernelScheduleType, KernelPtrArrayTmaWarpSpecializedCooperative> ||
         cute::is_same_v<KernelScheduleType, KernelPtrArrayTmaWarpSpecializedPingpong>) &&
        (detail::is_use_rmem_A<ElementA_, GmemLayoutATag_, ElementB_, GmemLayoutBTag_>() ||
         // ConvertAndScale and ConvertAndScaleWithZero
         cute::is_tuple<ElementA_>::value || cute::is_tuple<ElementB_>::value ||
         // DirectConvert
         sizeof_bits<ElementA_>::value != sizeof_bits<ElementB_>::value)>> {
 private:
  using ScaleA = detail::deduce_mixed_width_dtype_t<1, ElementA_>;
  using ScaleB = detail::deduce_mixed_width_dtype_t<1, ElementB_>;
  using ZeroA = detail::deduce_mixed_width_dtype_t<2, ElementA_>;
  using ZeroB = detail::deduce_mixed_width_dtype_t<2, ElementB_>;
  static constexpr bool NeitherIsTuple = !cute::is_tuple<ElementA_>::value && !cute::is_tuple<ElementB_>::value;
  // Determine if mixed input types.
  static constexpr bool IsMixedInput = cute::sizeof_bits_v<detail::deduce_mixed_width_dtype_t<0, ElementA_>> !=
                                       cute::sizeof_bits_v<detail::deduce_mixed_width_dtype_t<0, ElementB_>>;
  static constexpr bool IsArrayOfPointersGemm = cute::is_any_of_v<
      KernelScheduleType,
      KernelPtrArrayTmaWarpSpecializedCooperative,
      KernelPtrArrayTmaWarpSpecializedPingpong>;
  static_assert(IsMixedInput || !IsArrayOfPointersGemm, "Only mixed input grouped RS GEMM is supported.");

 public:
  using ElementA = detail::deduce_mixed_width_dtype_t<0, ElementA_>;
  using ElementB = detail::deduce_mixed_width_dtype_t<0, ElementB_>;

  static_assert(
      !IsMixedInput || (cute::is_tuple<ElementA_>::value ^ cute::is_tuple<ElementB_>::value ||
                        (NeitherIsTuple && (sizeof_bits<ElementA>::value != sizeof_bits<ElementB>::value))),
      "Either A OR B must be a tuple or the widths of A and B must be different.");

  static constexpr bool IsANarrow = sizeof_bits<ElementA>::value < sizeof_bits<ElementB>::value;

  template <class T>
  static auto get_stride(T const& t) {
    if constexpr (not cute::is_layout<cute::remove_pointer_t<T>>::value) {
      return t;
    } else {
      if constexpr (cute::is_pointer_v<T>) {
        return &cute::stride(*t);
      } else {
        return cute::stride(t);
      }
    }
  }

  using GmemLayoutATag = decltype(get_stride(GmemLayoutATag_{}));
  using GmemLayoutBTag = decltype(get_stride(GmemLayoutBTag_{}));

  using ElementPairA =
      cute::conditional_t<IsMixedInput && IsANarrow && NeitherIsTuple, cute::tuple<ElementA>, ElementA_>;
  using ElementPairB =
      cute::conditional_t<IsMixedInput && !IsANarrow && NeitherIsTuple, cute::tuple<ElementB>, ElementB_>;

  static constexpr bool IsATransformed = cute::is_tuple<ElementPairA>::value;
  using ElementScale = cute::conditional_t<IsATransformed, ScaleA, ScaleB>;
  using ElementZero = cute::conditional_t<IsATransformed, ZeroA, ZeroB>;

  static_assert(is_static<TileShape_MNK>::value);
  static_assert(is_static<ClusterShape_MNK>::value);
  static_assert(
      detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>(),
      "Should meet TMA alignment requirement\n");
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
  static_assert(cutlass::detail::dependent_false<ElementA>, "Unsupported Toolkit for SM90 Collective Builder\n");
#endif
  static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_rs_tag_to_major_A<GmemLayoutATag>();
  static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_rs_tag_to_major_B<GmemLayoutBTag>();
  // If A is scaled, then we don't need to swap. Otherwise, we must ensure B goes to rmem and we must swap the
  // operands.
  static constexpr bool SwapAB =
      IsMixedInput ? !IsATransformed : detail::is_swapAB<ElementA, GmemLayoutATag, ElementB, GmemLayoutBTag>();
  static constexpr bool IsWarpSpecializedTransposeB =
      detail::is_warpspecialized_transpose_B<ElementA, GmemLayoutATag, ElementB, GmemLayoutBTag, KernelScheduleType>();
  static_assert(!IsMixedInput || !IsWarpSpecializedTransposeB, "Mixed input GEMM does not support WS transpose B.");

  // When we relax the above assertion, we must handle setting the tile mma GmmaMajorB correctly.
  static constexpr cute::GMMA::Major TiledMmaGmmaMajorB = SwapAB ? GmmaMajorA : GmmaMajorB;

  // For fp32 types, map to tf32 MMA value type.
  using ElementAMma = cute::conditional_t<cute::is_same_v<ElementA, float>, tfloat32_t, ElementA>;
  using ElementBMma = cute::conditional_t<cute::is_same_v<ElementB, float>, tfloat32_t, ElementB>;

  // Handle mixed dtypes and MMA.
  using RealElementA = cute::conditional_t<SwapAB, ElementBMma, ElementAMma>;
  using RealElementB = cute::conditional_t<SwapAB, ElementAMma, ElementBMma>;
  using RealElementAMma = cute::conditional_t<IsMixedInput, RealElementB, RealElementA>;
  // Always the same for element B.
  using RealElementBMma = RealElementB;

  static_assert(
      !IsMixedInput || TiledMmaGmmaMajorB == GMMA::Major::K || sizeof_bits<RealElementB>::value == 16,
      "Mixed input GEMM does not support MN major layout except for 16bit");

  using AtomLayoutMNK = cute::conditional_t<
      cute::is_any_of_v<
          KernelScheduleType,
          KernelTmaWarpSpecializedCooperative,
          KernelPtrArrayTmaWarpSpecializedCooperative>,
      Layout<Shape<_2, _1, _1>>,
      Layout<Shape<_1, _1, _1>>>;

  using TiledMma = decltype(cute::make_tiled_mma(
      cute::GMMA::rs_op_selector<
          RealElementAMma,
          RealElementBMma,
          ElementAccumulator,
          TileShape_MNK,
          GMMA::Major::K,
          GMMA::Major::K>(),
      AtomLayoutMNK{}));

  using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
  using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{})));

  using SmemLayoutAtomA = decltype(detail::rs_smem_selector<
                                   GmmaMajorA,
                                   ElementAMma,
                                   decltype(cute::get<0>(TileShape_MNK{})),
                                   decltype(cute::get<2>(TileShape_MNK{})),
                                   IsWarpSpecializedTransposeB>());
  using SmemLayoutAtomB = decltype(detail::rs_smem_selector<
                                   GmmaMajorB,
                                   ElementBMma,
                                   decltype(cute::get<1>(TileShape_MNK{})),
                                   decltype(cute::get<2>(TileShape_MNK{})),
                                   IsWarpSpecializedTransposeB>());

  static constexpr size_t SmemAlignmentA = cutlass::detail::alignment_for_swizzle(SmemLayoutAtomA{});
  static constexpr size_t SmemAlignmentB = cutlass::detail::alignment_for_swizzle(SmemLayoutAtomB{});
  static constexpr int SmemAlignment = static_cast<int>(cute::max(SmemAlignmentA, SmemAlignmentB));

  // Handle mixed dtype array GEMM's size of tensor map storage.
  static constexpr size_t TensorMapStorage = sizeof(cute::TmaDescriptor) * size_t(IsMixedInput) * 4;
  static constexpr int KernelSmemCarveout = static_cast<int>(TensorMapStorage);
  static constexpr int Sm90ReducedSmemCapacityBytes = detail::sm90_smem_capacity_bytes - KernelSmemCarveout;

  static constexpr int PipelineStages =
      IsMixedInput ? (IsArrayOfPointersGemm ? detail::compute_stage_count_or_override_single_affine_transformed_input<
                                                  Sm90ReducedSmemCapacityBytes,
                                                  RealElementA,
                                                  RealElementB,
                                                  ElementScale,
                                                  ElementZero,
                                                  TileShape_MNK,
                                                  StageCountType::bytes,
                                                  SmemAlignment>(StageCountType{})
                                            : detail::compute_stage_count_or_override_single_affine_transformed_input<
                                                  detail::sm90_smem_capacity_bytes,
                                                  RealElementA,
                                                  RealElementB,
                                                  ElementScale,
                                                  ElementZero,
                                                  TileShape_MNK,
                                                  StageCountType::bytes,
                                                  SmemAlignment>(StageCountType{}))
                   : detail::compute_stage_count_or_override<
                         detail::sm90_smem_capacity_bytes,
                         ElementAMma,
                         ElementBMma,
                         TileShape_MNK,
                         StageCountType::bytes,
                         SmemAlignment>(StageCountType{});

  using DispatchPolicy = cute::conditional_t<
      IsMixedInput,
      cute::conditional_t<
          IsArrayOfPointersGemm,
          MainloopSm90ArrayTmaGmmaWarpSpecializedMixedInput<PipelineStages, ClusterShape_MNK, KernelScheduleType>,
          MainloopSm90TmaGmmaRmemAWarpSpecializedMixedInput<PipelineStages, ClusterShape_MNK, KernelScheduleType>>,
      MainloopSm90TmaGmmaRmemAWarpSpecialized<PipelineStages, ClusterShape_MNK, KernelScheduleType>>;

  using SmemCopyAtomA = cute::conditional_t<SwapAB, void, Copy_Atom<cute::AutoVectorizingCopy, ElementA>>;
  using SmemCopyAtomB = cute::conditional_t<SwapAB, Copy_Atom<cute::AutoVectorizingCopy, ElementB>, void>;

  // We pack the scale data with the operand that will be optionally scaled and converted before MMA.
  using StrideA = cute::conditional_t<
      cute::is_layout<cute::remove_pointer_t<GmemLayoutATag_>>::value,
      GmemLayoutATag_,
      TagToStrideA_t<GmemLayoutATag>>;
  using StrideB = cute::conditional_t<
      cute::is_layout<cute::remove_pointer_t<GmemLayoutBTag_>>::value,
      GmemLayoutBTag_,
      TagToStrideB_t<GmemLayoutBTag>>;

  using CollectiveOp = CollectiveMmaArrayMixedInput<
      DispatchPolicy,
      TileShape_MNK,
      ElementPairA,
      StrideA,
      ElementPairB,
      StrideB,
      TiledMma,
      GmemTiledCopyA,
      SmemLayoutAtomA,
      SmemCopyAtomA,
      cute::identity,
      GmemTiledCopyB,
      SmemLayoutAtomB,
      SmemCopyAtomB,
      cute::identity>;

  static_assert(
      SmemAlignment == static_cast<int>(cute::max(CollectiveOp::SmemAlignmentA, CollectiveOp::SmemAlignmentB)));
};

/////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace cutlass::gemm::collective

/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -0,0 +1,52 @@
/*
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

/////////////////////////////////////////////////////////////////////////////////////////////////
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass_extensions/gemm/collective/collective_mma_array_mixed_input.hpp"

namespace cutlass::gemm::collective {

/////////////////////////////////////////////////////////////////////////////////////////////////

template <
    class ArchTag,
    class OpClass,
    class ElementA,
    class GmemLayoutA,
    int AlignmentA,
    class ElementB,
    class GmemLayoutB,
    int AlignmentB,
    class ElementAccumulator,
    class TileShape_MNK,
    class ClusterShape_MNK,
    class StageCountType,
    class KernelScheduleType,
    class Enable = void>
struct CollectiveBuilderMixedInput {
  static_assert(sizeof(ElementA) == 0, "Could not build a collective for given parameters.");
};

/////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace cutlass::gemm::collective

/////////////////////////////////////////////////////////////////////////////////////////////////

#include "cutlass_extensions/gemm/collective/builders/sm90_gmma_builder_mixed_input.inl"
/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -0,0 +1,53 @@
/*
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include "cutlass/detail/dependent_false.hpp"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass::gemm::collective {

/////////////////////////////////////////////////////////////////////////////////////////////////

template <
    class DispatchPolicy,
    class TileShape,
    class ElementA,
    class StrideA,
    class ElementB,
    class StrideB,
    class TiledMma,
    class GmemTiledCopyA,
    class SmemLayoutAtomA,
    class SmemCopyAtomA,
    class TransformA,
    class GmemTiledCopyB,
    class SmemLayoutAtomB,
    class SmemCopyAtomB,
    class TransformB>
struct CollectiveMmaArrayMixedInput {
  static_assert(cutlass::detail::dependent_false<ElementA>, "Could not find a mainloop specialization.");
};

/////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace cutlass::gemm::collective

/////////////////////////////////////////////////////////////////////////////////////////////////

#include "cutlass_extensions/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input_.hpp"
/////////////////////////////////////////////////////////////////////////////////////////////////
File diff suppressed because it is too large.

sgl-kernel/csrc/moe/cutlass_moe/w4a8/scaled_mm_entry.cu (new file, 91 lines)
@@ -0,0 +1,91 @@
#include <c10/cuda/CUDAGuard.h>
#include <cudaTypedefs.h>
#include <torch/all.h>

int32_t get_sm_version_num() {
  int32_t major_capability, minor_capability;
  cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor, 0);
  cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor, 0);
  int32_t version_num = major_capability * 10 + minor_capability;
  return version_num;
}

void cutlass_w4a8_moe_mm_sm90(
    torch::Tensor& d_tensors,
    torch::Tensor const& a_tensors,
    torch::Tensor const& b_tensors,
    torch::Tensor const& a_scales,
    torch::Tensor const& b_scales,
    torch::Tensor const& expert_offsets,
    torch::Tensor const& problem_sizes,
    torch::Tensor const& a_strides,
    torch::Tensor const& b_strides,
    torch::Tensor const& d_strides,
    torch::Tensor const& s_strides,
    int64_t chunk_size,
    int64_t topk);

void get_cutlass_w4a8_moe_mm_data_caller(
    const torch::Tensor& topk_ids,
    torch::Tensor& expert_offsets,
    torch::Tensor& problem_sizes1,
    torch::Tensor& problem_sizes2,
    torch::Tensor& input_permutation,
    torch::Tensor& output_permutation,
    const int64_t num_experts,
    const int64_t n,
    const int64_t k);

void cutlass_w4a8_moe_mm(
    torch::Tensor& d_tensors,
    torch::Tensor const& a_tensors,
    torch::Tensor const& b_tensors,
    torch::Tensor const& a_scales,
    torch::Tensor const& b_scales,
    torch::Tensor const& expert_offsets,
    torch::Tensor const& problem_sizes,
    torch::Tensor const& a_strides,
    torch::Tensor const& b_strides,
    torch::Tensor const& d_strides,
    torch::Tensor const& s_strides,
    int64_t chunk_size,
    int64_t topk) {
  cutlass_w4a8_moe_mm_sm90(
      d_tensors,
      a_tensors,
      b_tensors,
      a_scales,
      b_scales,
      expert_offsets,
      problem_sizes,
      a_strides,
      b_strides,
      d_strides,
      s_strides,
      chunk_size,
      topk);
  return;
}

void get_cutlass_w4a8_moe_mm_data(
    const torch::Tensor& topk_ids,
    torch::Tensor& expert_offsets,
    torch::Tensor& problem_sizes1,
    torch::Tensor& problem_sizes2,
    torch::Tensor& input_permutation,
    torch::Tensor& output_permutation,
    const int64_t num_experts,
    const int64_t n,
    const int64_t k) {
  get_cutlass_w4a8_moe_mm_data_caller(
      topk_ids,
      expert_offsets,
      problem_sizes1,
      problem_sizes2,
      input_permutation,
      output_permutation,
      num_experts,
      n,
      k);
  return;
}
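Editor's note: get_sm_version_num() is defined in this translation unit but not called in the code shown here. A natural (hypothetical, not part of this commit) use is to reject pre-Hopper devices before entering the SM90-only path:

// Hedged sketch: guard the Hopper-only kernel with a device-capability check.
// TORCH_CHECK(get_sm_version_num() >= 90,
//             "cutlass_w4a8_moe_mm requires compute capability 9.0 (Hopper) or newer");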
@@ -0,0 +1,92 @@
#pragma once

#include <c10/cuda/CUDAStream.h>
#include <cuda.h>
#include <torch/all.h>

#include "cutlass/bfloat16.h"
#include "cutlass/float8.h"

template <typename ElementA, typename ElementB, typename ElementC, typename ElementAccumulator>
__global__ void int4_fp8_get_group_gemm_starts(
    int32_t* expert_offsets,
    ElementA** a_offsets,
    ElementB** b_offsets,
    ElementC** out_offsets,
    ElementAccumulator** a_scales_offsets,
    cutlass::bfloat16_t** b_scales_offsets,
    ElementA* a_base_as_int,
    ElementB* b_base_as_int,
    ElementC* out_base_as_int,
    ElementAccumulator* a_scales_base_as_int,
    cutlass::bfloat16_t* b_scales_base_as_int,
    int64_t n,
    int64_t k,
    bool per_act_token,
    bool per_out_ch) {
  int expert_id = threadIdx.x;
  int32_t expert_offset = expert_offsets[expert_id];

  a_offsets[expert_id] = a_base_as_int + expert_offset * k;
  b_offsets[expert_id] = b_base_as_int + expert_id * k * n / 2;
  out_offsets[expert_id] = out_base_as_int + expert_offset * n;
  a_scales_offsets[expert_id] = a_scales_base_as_int + (per_act_token ? expert_offset : 0);
  b_scales_offsets[expert_id] = b_scales_base_as_int + (per_out_ch ? expert_id * n * 4 * k / 512 : expert_id);
}

#define __CALL_W4A8_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE)                         \
  else if (out_tensors.dtype() == TENSOR_C_TYPE) {                                   \
    int4_fp8_get_group_gemm_starts<cutlass::float_e4m3_t, cutlass::int8_t, C_TYPE, float> \
        <<<1, num_experts, 0, stream>>>(                                             \
            static_cast<int32_t*>(expert_offsets.data_ptr()),                        \
            static_cast<cutlass::float_e4m3_t**>(a_ptrs.data_ptr()),                 \
            static_cast<cutlass::int8_t**>(b_ptrs.data_ptr()),                       \
            static_cast<C_TYPE**>(out_ptrs.data_ptr()),                              \
            static_cast<float**>(a_scales_ptrs.data_ptr()),                          \
            static_cast<cutlass::bfloat16_t**>(b_scales_ptrs.data_ptr()),            \
            static_cast<cutlass::float_e4m3_t*>(a_tensors.data_ptr()),               \
            static_cast<cutlass::int8_t*>(b_tensors.data_ptr()),                     \
            static_cast<C_TYPE*>(out_tensors.data_ptr()),                            \
            static_cast<float*>(a_scales.data_ptr()),                                \
            static_cast<cutlass::bfloat16_t*>(b_scales.data_ptr()),                  \
            out_tensors.size(1),                                                     \
            a_tensors.size(1),                                                       \
            per_act_token,                                                           \
            per_out_ch);                                                             \
  }

namespace {

void run_int4_fp8_get_group_gemm_starts(
    torch::Tensor const& expert_offsets,
    torch::Tensor& a_ptrs,
    torch::Tensor& b_ptrs,
    torch::Tensor& out_ptrs,
    torch::Tensor& a_scales_ptrs,
    torch::Tensor& b_scales_ptrs,
    torch::Tensor const& a_tensors,
    torch::Tensor const& b_tensors,
    torch::Tensor& out_tensors,
    torch::Tensor const& a_scales,
    torch::Tensor const& b_scales) {
  TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn);
  TORCH_CHECK(b_tensors.dtype() == torch::kInt8);
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kBFloat16);

  int num_experts = static_cast<int>(expert_offsets.size(0));
  bool per_act_token = a_scales.numel() != 1;
  bool per_out_ch = b_scales.numel() != num_experts;

  auto stream = at::cuda::getCurrentCUDAStream(expert_offsets.device().index());

  if (false) {
  }
  __CALL_W4A8_GET_STARTS_KERNEL(torch::kBFloat16, cutlass::bfloat16_t)
  __CALL_W4A8_GET_STARTS_KERNEL(torch::kFloat16, half)
  else {
    TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
  }
}

}  // namespace
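Editor's note: the b_scales offset expert_id * n * 4 * k / 512 in the kernel above implies one packed group of four bf16 scales per 512 int4 weights, i.e. one scale per 128 weights along k; the group size of 128 is inferred from this expression and from ElementScalePacked = cutlass::Array<ElementScale, 4> later in the diff, not stated explicitly. For the DeepSeek-like shapes handled by the dispatcher below (n = 4096, k = 7168) the per-expert scale count works out as in this small sketch:

// Per-expert scale count implied by the offset arithmetic (assuming group size 128 along k).
constexpr long long n_example = 4096, k_example = 7168;
constexpr long long scales_per_expert = n_example * 4 * k_example / 512;  // 229376 bf16 values
static_assert(scales_per_expert == n_example * k_example / 128, "one scale per 128 weights");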
sgl-kernel/csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cu (new file, 240 lines)
@@ -0,0 +1,240 @@
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <cudaTypedefs.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "w4a8_grouped_mm_c3x.cuh"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
namespace {
|
||||
|
||||
#define JOIN_STRUCT_NAME(m, n, k, a, b, c) sm90_fp8_config##_##m##_##n##_##k##_##a##_##b##_##c
|
||||
|
||||
#define JOIN_STRUCT_NAME_CO(m, n, k, a, b, c) sm90_fp8_co_config##_##m##_##n##_##k##_##a##_##b##_##c
|
||||
|
||||
#define GENERATE_SM90_W4A8_PP_CONFIG(M, N, K, A, B, C) \
|
||||
struct JOIN_STRUCT_NAME(M, N, K, A, B, C) { \
|
||||
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; \
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; \
|
||||
using TileShape = cute::Shape<cute::Int<M>, cute::Int<N>, cute::Int<K>>; \
|
||||
using ClusterShape = cute::Shape<cute::Int<A>, cute::Int<B>, cute::Int<C>>; \
|
||||
\
|
||||
using Cutlass3xW4A8Gemm = cutlass_3x_w4a8_group_gemm<TileShape, ClusterShape, KernelSchedule, EpilogueSchedule>; \
|
||||
};
|
||||
|
||||
#define GENERATE_SM90_W4A8_CO_CONFIG(M, N, K, A, B, C) \
|
||||
struct JOIN_STRUCT_NAME_CO(M, N, K, A, B, C) { \
|
||||
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; \
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; \
|
||||
using TileShape = cute::Shape<cute::Int<M>, cute::Int<N>, cute::Int<K>>; \
|
||||
using ClusterShape = cute::Shape<cute::Int<A>, cute::Int<B>, cute::Int<C>>; \
|
||||
\
|
||||
using Cutlass3xW4A8Gemm = cutlass_3x_w4a8_group_gemm<TileShape, ClusterShape, KernelSchedule, EpilogueSchedule>; \
|
||||
};
|
||||
|
||||
GENERATE_SM90_W4A8_PP_CONFIG(64, 16, 512, 1, 1, 1)
|
||||
GENERATE_SM90_W4A8_PP_CONFIG(64, 32, 512, 2, 1, 1)
|
||||
|
||||
GENERATE_SM90_W4A8_CO_CONFIG(128, 16, 512, 1, 1, 1)
|
||||
GENERATE_SM90_W4A8_CO_CONFIG(128, 16, 512, 2, 1, 1)
|
||||
GENERATE_SM90_W4A8_CO_CONFIG(128, 32, 512, 1, 1, 1)
|
||||
GENERATE_SM90_W4A8_CO_CONFIG(128, 32, 512, 2, 1, 1)
|
||||
GENERATE_SM90_W4A8_CO_CONFIG(128, 64, 512, 1, 1, 1)
|
||||
|
||||
void dispatch_w4a8_moe_mm_sm90(
|
||||
torch::Tensor& d_tensors,
|
||||
torch::Tensor const& a_tensors,
|
||||
torch::Tensor const& b_tensors,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& expert_offsets,
|
||||
torch::Tensor const& problem_sizes,
|
||||
torch::Tensor const& a_strides,
|
||||
torch::Tensor const& b_strides,
|
||||
torch::Tensor const& d_strides,
|
||||
torch::Tensor const& s_strides,
|
||||
int64_t chunk_size,
|
||||
int64_t topk) {
|
||||
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative;
|
||||
  using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative;

  uint32_t const m = a_tensors.size(0) / topk;
  uint32_t const n = d_tensors.size(1);
  uint32_t const k = a_tensors.size(1);

  if (n == 4096 && k == 7168) {
    // group gemm 1
    if (m <= 4) {
      using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME(64, 32, 512, 2, 1, 1)::Cutlass3xW4A8Gemm;
      cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
          d_tensors,
          a_tensors,
          b_tensors,
          a_scales,
          b_scales,
          expert_offsets,
          problem_sizes,
          a_strides,
          b_strides,
          d_strides,
          s_strides,
          chunk_size);
    } else if (m <= 16) {
      using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 16, 512, 2, 1, 1)::Cutlass3xW4A8Gemm;
      cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
          d_tensors,
          a_tensors,
          b_tensors,
          a_scales,
          b_scales,
          expert_offsets,
          problem_sizes,
          a_strides,
          b_strides,
          d_strides,
          s_strides,
          chunk_size);
    } else if (m <= 256) {
      using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 16, 512, 1, 1, 1)::Cutlass3xW4A8Gemm;
      cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
          d_tensors,
          a_tensors,
          b_tensors,
          a_scales,
          b_scales,
          expert_offsets,
          problem_sizes,
          a_strides,
          b_strides,
          d_strides,
          s_strides,
          chunk_size);
    } else if (m <= 1024) {
      using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 32, 512, 2, 1, 1)::Cutlass3xW4A8Gemm;
      cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
          d_tensors,
          a_tensors,
          b_tensors,
          a_scales,
          b_scales,
          expert_offsets,
          problem_sizes,
          a_strides,
          b_strides,
          d_strides,
          s_strides,
          chunk_size);
    } else {
      using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 64, 512, 1, 1, 1)::Cutlass3xW4A8Gemm;
      cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
          d_tensors,
          a_tensors,
          b_tensors,
          a_scales,
          b_scales,
          expert_offsets,
          problem_sizes,
          a_strides,
          b_strides,
          d_strides,
          s_strides,
          chunk_size);
    }
  } else if (n == 7168 && k == 2048) {
    // group gemm 2
    if (m <= 8) {
      using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME(64, 16, 512, 1, 1, 1)::Cutlass3xW4A8Gemm;
      cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
          d_tensors,
          a_tensors,
          b_tensors,
          a_scales,
          b_scales,
          expert_offsets,
          problem_sizes,
          a_strides,
          b_strides,
          d_strides,
          s_strides,
          chunk_size);
    } else if (m <= 512) {
      using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 32, 512, 1, 1, 1)::Cutlass3xW4A8Gemm;
      cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
          d_tensors,
          a_tensors,
          b_tensors,
          a_scales,
          b_scales,
          expert_offsets,
          problem_sizes,
          a_strides,
          b_strides,
          d_strides,
          s_strides,
          chunk_size);
    } else {
      using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 64, 512, 1, 1, 1)::Cutlass3xW4A8Gemm;
      cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
          d_tensors,
          a_tensors,
          b_tensors,
          a_scales,
          b_scales,
          expert_offsets,
          problem_sizes,
          a_strides,
          b_strides,
          d_strides,
          s_strides,
          chunk_size);
    }
  } else {
    using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 32, 512, 1, 1, 1)::Cutlass3xW4A8Gemm;
    cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
        d_tensors,
        a_tensors,
        b_tensors,
        a_scales,
        b_scales,
        expert_offsets,
        problem_sizes,
        a_strides,
        b_strides,
        d_strides,
        s_strides,
        chunk_size);
  }
}

} // namespace

void cutlass_w4a8_moe_mm_sm90(
    torch::Tensor& d_tensors,
    torch::Tensor const& a_tensors,
    torch::Tensor const& b_tensors,
    torch::Tensor const& a_scales,
    torch::Tensor const& b_scales,
    torch::Tensor const& expert_offsets,
    torch::Tensor const& problem_sizes,
    torch::Tensor const& a_strides,
    torch::Tensor const& b_strides,
    torch::Tensor const& d_strides,
    torch::Tensor const& s_strides,
    int64_t chunk_size,
    int64_t topk) {
  dispatch_w4a8_moe_mm_sm90(
      d_tensors,
      a_tensors,
      b_tensors,
      a_scales,
      b_scales,
      expert_offsets,
      problem_sizes,
      a_strides,
      b_strides,
      d_strides,
      s_strides,
      chunk_size,
      topk);
}

sgl-kernel/csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cuh (new file)
@@ -0,0 +1,276 @@
#pragma once

/**
 * @file w4a8_grouped_mm_c3x.cuh
 * @brief Implementation of grouped GEMM operation with int4 and fp8 mixed
 * precision
 *
 * This file implements a grouped GEMM operation that multiplies FP8 matrices
 * (A) with quantized INT4 matrices (B), applying per-block scaling factors.
 * The implementation is optimized for NVIDIA Hopper GPUs, leveraging Tensor
 * Cores for mixed precision arithmetic.
 *
 * Key features:
 * - Supports grouped GEMM operations with multiple experts
 * - Uses FP8 (e4m3) for matrix A
 * - Uses INT4 quantization for matrix B with per-block scaling
 * - Implements preprocessing for INT4 encoding and scale packing
 * - Optimized for Hopper architecture with Tensor Core operations
 */

#include <ATen/cuda/CUDAContext.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include <torch/all.h>

#include "cutlass/cutlass.h"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/group_array_problem_shape.hpp"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass_extensions/gemm/collective/collective_builder_mixed_input.hpp"
#include "w4a8_get_group_starts.cuh"

using namespace cute;

namespace {

// Type definitions
using MmaType = cutlass::float_e4m3_t;     // FP8 e4m3 type
using QuantType = cutlass::int4b_t;        // 4-bit integer type
using ElementAccumulator = float;          // Accumulator type
using ElementScale = cutlass::bfloat16_t;  // Scale type
using ElementScalePacked = cutlass::Array<ElementScale, 4>;
using ElementC = cutlass::half_t;  // Default output type (FP16)
using ElementD = ElementC;         // Default output type (FP16)
using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int, int, int>>;

// Architecture-specific configurations
using ArchTag = cutlass::arch::Sm90;
using OperatorClass = cutlass::arch::OpClassTensorOp;
// constexpr int TileShapeK = 512;
// using TileShape = Shape<_128, _32, cute::Int<TileShapeK>>;
// using ClusterShape = Shape<_1, _1, _1>;

// Layout configurations
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = LayoutC;

// Transposed layouts
using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose<LayoutA>::type;
using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose<LayoutB>::type;
using LayoutC_Transpose = typename cutlass::layout::LayoutTranspose<LayoutC>::type;
using LayoutD_Transpose = typename cutlass::layout::LayoutTranspose<LayoutD>::type;

// Alignments
static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<MmaType>::value;
static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<QuantType>::value;
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;

template <typename TileShape, typename ClusterShape, typename KernelSchedule, typename EpilogueSchedule>
struct cutlass_3x_w4a8_group_gemm {
  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      ArchTag,
      OperatorClass,
      TileShape,
      ClusterShape,
      cutlass::epilogue::collective::EpilogueTileAuto,
      ElementAccumulator,
      ElementAccumulator,
      ElementC,
      LayoutC_Transpose*,
      AlignmentC,
      ElementD,
      LayoutD_Transpose*,
      AlignmentD,
      EpilogueSchedule>::CollectiveOp;

  using CollectiveMainloopScaleOnly = typename cutlass::gemm::collective::CollectiveBuilderMixedInput<
      ArchTag,
      OperatorClass,
      cute::tuple<QuantType, ElementScalePacked>,
      LayoutB_Transpose*,
      AlignmentB,
      MmaType,
      LayoutA_Transpose*,
      AlignmentA,
      ElementAccumulator,
      TileShape,
      ClusterShape,
      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
          sizeof(typename CollectiveEpilogue::SharedStorage))>,
      KernelSchedule>::CollectiveOp;

  // Define the final kernel and GEMM operation types
  using GemmKernelScaleOnly =
      cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloopScaleOnly, CollectiveEpilogue>;

  using GemmScaleOnly = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelScaleOnly>;

  using StrideA = cute::remove_pointer_t<cutlass::detail::TagToStrideA_t<LayoutA*>>;
  using StrideB = cute::remove_pointer_t<cutlass::detail::TagToStrideB_t<LayoutB*>>;
  using StrideC = typename GemmKernelScaleOnly::InternalStrideC;
  using StrideD = typename GemmKernelScaleOnly::InternalStrideD;
  using StrideS = typename CollectiveMainloopScaleOnly::StrideScale;
};

/**
 * @brief Main function to run int4 * fp8 grouped GEMM from PyTorch
 *
 * This function performs multiple GEMM operations in parallel where each
 * operation multiplies an FP8 matrix (A) with a quantized INT4 matrix (B),
 * applying per-channel scaling factors. It's designed for efficient execution
 * on NVIDIA Hopper GPUs, leveraging Tensor Cores for optimal performance with
 * mixed precision arithmetic.
 *
 * The function includes preprocessing steps for both INT4 tensors and scale
 * factors to ensure optimal performance and correct operation.
 *
 * @param d_tensors Output tensor D with shape [total_m, total_n]
 * @param a_tensors Tensor containing all A matrices (fp8_e4m3) with shape
 * [total_m, K]
 * @param b_tensors Tensor containing all B matrices (int4 packed as int8) with
 * shape [E, N, K/2]
 * @param a_scales Tensor containing A matrix scale factors
 * @param b_scales Tensor containing B matrix scale factors with shape [E,
 * K//512, N*4]
 * @param expert_offsets Tensor containing expert offsets for determining group
 * boundaries (int32)
 * @param problem_sizes Tensor containing problem sizes with shape [num_experts,
 * 3] (N, M, K for each group) (int32)
 * @param a_strides Stride information for A tensors
 * @param b_strides Stride information for B tensors
 * @param d_strides Stride information for D tensors
 * @param s_strides Stride information for scale tensors
 * @param chunk_size Size of each chunk for scales (K / number of scale chunks)
 */
// template <typename TileShape, typename ClusterShape, typename KernelSchedule, typename EpilogueSchedule>
template <typename Gemm>
void cutlass_w4a8_group_gemm_caller(
    torch::Tensor& d_tensors,
    torch::Tensor const& a_tensors,
    torch::Tensor const& b_tensors,
    torch::Tensor const& a_scales,
    torch::Tensor const& b_scales,
    torch::Tensor const& expert_offsets,
    torch::Tensor const& problem_sizes,
    torch::Tensor const& a_strides,
    torch::Tensor const& b_strides,
    torch::Tensor const& d_strides,
    torch::Tensor const& s_strides,
    int64_t chunk_size) {
  // using Gemm = cutlass_3x_w4a8_group_gemm<TileShape, ClusterShape, KernelSchedule, EpilogueSchedule>;
  using Args = typename Gemm::GemmScaleOnly::Arguments;

  int num_experts = static_cast<int>(expert_offsets.size(0));
  bool per_act_token = a_scales.numel() != 1;
  bool per_out_ch = b_scales.numel() != num_experts;

  // Check inputs
  TORCH_CHECK(a_tensors.dim() == 2, "A tensor must be 2D");
  TORCH_CHECK(b_tensors.dim() == 3, "B tensor must be 3D [E, N, K/2]");
  TORCH_CHECK(b_scales.dim() == 3, "Scale tensor must be 3D [E, K//512, N*4]");
  TORCH_CHECK(a_scales.dim() == 1, "A Scale tensor must be 1D [1]");
  TORCH_CHECK(expert_offsets.dim() == 1, "expert_offsets must be a 1D tensor");
  TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor");

  // Check tensor shapes
  TORCH_CHECK(problem_sizes.size(0) == num_experts, "problem_sizes must have num_experts rows");
  TORCH_CHECK(problem_sizes.size(1) == 3, "problem_sizes must have 3 columns (N, M, K)");
  TORCH_CHECK(b_tensors.size(0) == num_experts, "B tensor first dimension must match number of groups");
  TORCH_CHECK(b_scales.size(0) == num_experts, "Scale tensor first dimension must match number of groups");
  TORCH_CHECK(b_tensors.size(2) * 2 == a_tensors.size(1), "B tensor K/2 dimension must match A tensor K dimension");
  TORCH_CHECK(b_scales.size(1) == a_tensors.size(1) / 512, "Scale tensor second dimension must be K//512");
  TORCH_CHECK(b_scales.size(2) == 4 * b_tensors.size(1), "Scale tensor last dimension must be 4*N");

  // Check tensor types
  TORCH_CHECK(a_tensors.scalar_type() == torch::kFloat8_e4m3fn, "A tensor must be fp8 (float_e4m3_t) type");
  TORCH_CHECK(b_tensors.scalar_type() == torch::kInt8, "B tensor must contain packed int4 values (stored as int8)");
  TORCH_CHECK(expert_offsets.scalar_type() == torch::kInt32, "Expert offsets must be int32 type");
  TORCH_CHECK(problem_sizes.scalar_type() == torch::kInt32, "Problem sizes must be int32 type");

  auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
  auto options_int = torch::TensorOptions().dtype(torch::kInt64).device(a_tensors.device());

  torch::Tensor a_ptrs = torch::empty(num_experts, options_int);
  torch::Tensor b_ptrs = torch::empty(num_experts, options_int);
  torch::Tensor out_ptrs = torch::empty(num_experts, options_int);
  torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int);
  torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int);

  cutlass::KernelHardwareInfo hw_info;
  hw_info.device_id = a_tensors.device().index();
  hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);

  Args arguments;
  decltype(arguments.epilogue.thread) fusion_args;
  fusion_args.alpha = 1.0f;
  fusion_args.beta = 0;
  fusion_args.alpha_ptr = a_scales.data_ptr<float>();
  fusion_args.beta_ptr = nullptr;
  fusion_args.alpha_ptr_array = nullptr;
  fusion_args.beta_ptr_array = nullptr;
  fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 0};
  fusion_args.dBeta = {cute::_0{}, cute::_0{}, 0};

  ProblemShape::UnderlyingProblemShape* problem_sizes_as_shapes =
      static_cast<ProblemShape::UnderlyingProblemShape*>(problem_sizes.data_ptr());

  run_int4_fp8_get_group_gemm_starts(
      expert_offsets,
      a_ptrs,
      b_ptrs,
      out_ptrs,
      a_scales_ptrs,
      b_scales_ptrs,
      a_tensors,
      b_tensors,
      d_tensors,
      a_scales,
      b_scales);

  arguments = Args{
      cutlass::gemm::GemmUniversalMode::kGrouped,
      {num_experts, problem_sizes_as_shapes, nullptr},
      {static_cast<const QuantType**>(b_ptrs.data_ptr()),
       static_cast<typename Gemm::StrideB*>(b_strides.data_ptr()),
       static_cast<const MmaType**>(a_ptrs.data_ptr()),
       static_cast<typename Gemm::StrideA*>(a_strides.data_ptr()),
       static_cast<const ElementScalePacked**>(b_scales_ptrs.data_ptr()),
       static_cast<typename Gemm::StrideS*>(s_strides.data_ptr()),
       static_cast<int>(chunk_size)},
      {fusion_args,
       nullptr,
       nullptr,
       static_cast<ElementD**>(out_ptrs.data_ptr()),
       static_cast<typename Gemm::StrideD*>(d_strides.data_ptr())},
      hw_info};

  // Instantiate and run GEMM
  typename Gemm::GemmScaleOnly gemm;
  size_t workspace_size = Gemm::GemmScaleOnly::get_workspace_size(arguments);
  auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a_tensors.device());
  auto workspace = torch::empty(workspace_size, workspace_options);

  cutlass::Status status = gemm.can_implement(arguments);
  if (status != cutlass::Status::kSuccess) {
    TORCH_CHECK(false, "GEMM implementation not supported");
  }

  status = gemm.initialize(arguments, workspace.data_ptr(), stream);
  if (status != cutlass::Status::kSuccess) {
    TORCH_CHECK(false, "GEMM initialization failed");
  }

  status = gemm.run(stream);
  if (status != cutlass::Status::kSuccess) {
    TORCH_CHECK(false, "GEMM execution failed");
  }
}

} // namespace

sgl-kernel/csrc/moe/cutlass_moe/w4a8/w4a8_moe_data.cu (new file)
@@ -0,0 +1,79 @@
#include <ATen/cuda/CUDAContext.h>  // at::cuda::getCurrentCUDAStream
#include <c10/cuda/CUDAGuard.h>
#include <cudaTypedefs.h>
#include <torch/all.h>

#include <iostream>

constexpr uint64_t THREADS_PER_EXPERT = 512;

__global__ void compute_problem_sizes_w4a8(
    const int32_t* __restrict__ topk_ids,
    int32_t* problem_sizes1,
    int32_t* problem_sizes2,
    int32_t* atomic_buffer,
    const int topk_length,
    const int n,
    const int k) {
  int expert_id = blockIdx.x;

  int occurrences = 0;
  for (int i = threadIdx.x; i < topk_length; i += THREADS_PER_EXPERT) {
    occurrences += (topk_ids[i] == expert_id);
  }
  atomicAdd(&atomic_buffer[expert_id], occurrences);
  __syncthreads();

  if (threadIdx.x == 0) {
    int final_occurrences = atomic_buffer[expert_id];
    problem_sizes1[expert_id * 3] = 2 * n;
    problem_sizes1[expert_id * 3 + 1] = final_occurrences;
    problem_sizes1[expert_id * 3 + 2] = k;
    problem_sizes2[expert_id * 3] = k;
    problem_sizes2[expert_id * 3 + 1] = final_occurrences;
    problem_sizes2[expert_id * 3 + 2] = n;
  }
}

__global__ void compute_expert_offsets_w4a8(
    const int32_t* __restrict__ problem_sizes1,
    int32_t* expert_offsets,
    int32_t* atomic_buffer,
    const int num_experts) {
  int32_t tot_offset = 0;
  expert_offsets[0] = 0;
  for (int i = 0; i < num_experts; ++i) {
    atomic_buffer[i] = tot_offset;
    tot_offset += problem_sizes1[i * 3 + 1];
    expert_offsets[i + 1] = tot_offset;
  }
}

void get_cutlass_w4a8_moe_mm_data_caller(
    const torch::Tensor& topk_ids,
    torch::Tensor& expert_offsets,
    torch::Tensor& problem_sizes1,
    torch::Tensor& problem_sizes2,
    torch::Tensor& input_permutation,
    torch::Tensor& output_permutation,
    const int64_t num_experts,
    const int64_t n,
    const int64_t k) {
  auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index());
  auto options_int32 = torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device());
  torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32);

  int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
  compute_problem_sizes_w4a8<<<num_experts, num_threads, 0, stream>>>(
      static_cast<const int32_t*>(topk_ids.data_ptr()),
      static_cast<int32_t*>(problem_sizes1.data_ptr()),
      static_cast<int32_t*>(problem_sizes2.data_ptr()),
      static_cast<int32_t*>(atomic_buffer.data_ptr()),
      topk_ids.numel(),
      n,
      k);
  compute_expert_offsets_w4a8<<<1, 1, 0, stream>>>(
      static_cast<const int32_t*>(problem_sizes1.data_ptr()),
      static_cast<int32_t*>(expert_offsets.data_ptr()),
      static_cast<int32_t*>(atomic_buffer.data_ptr()),
      num_experts);
}
@@ -467,6 +467,35 @@ void transfer_kv_all_layer_mla_direct(
    int64_t page_size,
    int64_t num_layers);

/*
 * From csrc/moe/cutlass_moe/w4a8
 */
void get_cutlass_w4a8_moe_mm_data(
    const torch::Tensor& topk_ids,
    torch::Tensor& expert_offsets,
    torch::Tensor& problem_sizes1,
    torch::Tensor& problem_sizes2,
    torch::Tensor& input_permutation,
    torch::Tensor& output_permutation,
    const int64_t num_experts,
    const int64_t n,
    const int64_t k);

void cutlass_w4a8_moe_mm(
    torch::Tensor& d_tensors,
    torch::Tensor const& a_tensors,
    torch::Tensor const& b_tensors,
    torch::Tensor const& a_scales,
    torch::Tensor const& b_scales,
    torch::Tensor const& expert_offsets,
    torch::Tensor const& problem_sizes,
    torch::Tensor const& a_strides,
    torch::Tensor const& b_strides,
    torch::Tensor const& d_strides,
    torch::Tensor const& s_strides,
    int64_t chunk_size,
    int64_t topk);

/*
 * From FlashInfer
 */
@@ -19,6 +19,7 @@ from sgl_kernel.attention import (
    merge_state,
    merge_state_v2,
)
from sgl_kernel.cutlass_moe import cutlass_w4a8_moe_mm, get_cutlass_w4a8_moe_mm_data
from sgl_kernel.elementwise import (
    apply_rope_with_cos_sin_cache_inplace,
    fused_add_rmsnorm,

sgl-kernel/python/sgl_kernel/cutlass_moe.py (new file)
@@ -0,0 +1,112 @@
import torch


def get_cutlass_w4a8_moe_mm_data(
    topk_ids: torch.Tensor,
    expert_offsets: torch.Tensor,
    problem_sizes1: torch.Tensor,
    problem_sizes2: torch.Tensor,
    input_permutation: torch.Tensor,
    output_permutation: torch.Tensor,
    num_experts: int,
    n: int,
    k: int,
):
    """
    Prepare data necessary to perform CUTLASS grouped matrix multiplications
    used in CUTLASS-based fused MoE.

    The function takes in topk_ids (token-expert mapping) and uses it to
    compute:
    - expert_offsets: Indices that mark at which token index each expert begins
                      its computation after the input is sorted with
                      input_permutation. The number of tokens computed with
                      expert E is expert_offsets[E + 1] - expert_offsets[E]
    - problem_sizes1, problem_sizes2: MxNxK sizes of each expert's
                                      multiplication in two grouped MMs used in
                                      the fused MoE operation.
    - input_permutation: Permutation that must be used to shuffle the input
                         before executing the MMs.
    - output_permutation: Permutation that must be used to shuffle the output
                          after executing the MMs.
    """
    torch.ops.sgl_kernel.get_cutlass_w4a8_moe_mm_data.default(
        topk_ids,
        expert_offsets,
        problem_sizes1,
        problem_sizes2,
        input_permutation,
        output_permutation,
        num_experts,
        n,
        k,
    )
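

# Editorial sketch, not part of the original file: a minimal example of the
# output buffers this op fills, assuming expert_offsets holds num_experts + 1
# prefix sums (as written by compute_expert_offsets_w4a8) and that the two
# permutation buffers hold one int32 entry per (token, expert) pair; the
# latter sizing is an assumption, and `_example_prepare_w4a8_moe_mm_data` is
# a hypothetical helper used only for illustration.
def _example_prepare_w4a8_moe_mm_data():
    device = "cuda"
    num_tokens, topk, num_experts = 32, 8, 16
    n, k = 4096, 7168  # per-expert weight shape [n, k]
    topk_ids = torch.randint(
        0, num_experts, (num_tokens, topk), dtype=torch.int32, device=device
    )
    expert_offsets = torch.empty(num_experts + 1, dtype=torch.int32, device=device)
    problem_sizes1 = torch.empty(num_experts, 3, dtype=torch.int32, device=device)
    problem_sizes2 = torch.empty(num_experts, 3, dtype=torch.int32, device=device)
    input_permutation = torch.empty(
        num_tokens * topk, dtype=torch.int32, device=device
    )
    output_permutation = torch.empty(
        num_tokens * topk, dtype=torch.int32, device=device
    )
    get_cutlass_w4a8_moe_mm_data(
        topk_ids,
        expert_offsets,
        problem_sizes1,
        problem_sizes2,
        input_permutation,
        output_permutation,
        num_experts,
        n,
        k,
    )
    return expert_offsets, problem_sizes1, problem_sizes2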


def cutlass_w4a8_moe_mm(
    d: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    a_scales: torch.Tensor,
    b_scales: torch.Tensor,
    experts_offsets: torch.Tensor,
    problem_sizes: torch.Tensor,
    a_strides: torch.Tensor,
    b_strides: torch.Tensor,
    d_strides: torch.Tensor,
    s_strides: torch.Tensor,
    chunk_size: int = 128,
    topk: int = 8,
):
    """
    Perform grouped matrix multiplication between int4 weights and fp8 activations.

    This function executes multiple GEMM operations in parallel, which is useful for
    scenarios like Mixture of Experts (MoE) where different inputs go through different
    experts. The implementation leverages NVIDIA Hopper architecture features for
    optimal performance with quantized weights.

    Args:
        d: Output matrices of shape [total_m, total_n]
        a: Activation matrices in FP8 (float_e4m3_t) format.
           Each tensor should be of shape [total_m, K] in row-major layout
        b: Weight matrices in packed int4 format.
           Each tensor should be of shape [E, N, K//2] in column-major layout,
           where each byte contains two 4-bit integers
        a_scales: Scale factors for the inputs
        b_scales: Scale factors for the quantized weights.
           Each tensor should be of shape [E, K//512, N*4]
        experts_offsets: Tensor containing expert offsets for determining group boundaries
        problem_sizes: Tensor of shape [num_experts, 3] holding the (N, M, K) size of each group (int32)
        a_strides: Stride information for A matrices
        b_strides: Stride information for B matrices
        d_strides: Stride information for D matrices
        s_strides: Stride information for b_scales matrices
        chunk_size: Number of K elements each scale value applies to (the quantization
           group size), defaults to 128
        topk: Number of experts selected per token, defaults to 8

    Requirements:
        - All tensors must be on a CUDA device
        - Requires an NVIDIA Hopper GPU (H100)
        - A tensors must be in float8_e4m3fn format
        - B tensors must contain packed int4 values (stored as int8)

    Note:
        The function computes: D = (A * (B * scales))
        for each group of tensors in parallel
    """

    torch.ops.sgl_kernel.cutlass_w4a8_moe_mm.default(
        d,
        a,
        b,
        a_scales,
        b_scales,
        experts_offsets,
        problem_sizes,
        a_strides,
        b_strides,
        d_strides,
        s_strides,
        chunk_size,
        topk,
    )
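

# Editorial shape sketch, not part of the original file: how the tensor shapes
# described in the docstring above fit together for E experts with per-expert
# weights of logical shape [N, K]. Dtypes follow the TORCH_CHECKs and element
# types in w4a8_grouped_mm_c3x.cuh (fp8 activations, packed int4 weights, bf16
# group scales, fp16 output, a single fp32 activation scale); the helper name
# is hypothetical.
def _example_w4a8_moe_mm_shapes(E: int = 8, M: int = 64, N: int = 4096, K: int = 7168):
    device = "cuda"
    d = torch.empty(M, N, dtype=torch.float16, device=device)  # output
    a = torch.empty(M, K, dtype=torch.float8_e4m3fn, device=device)  # fp8 activations
    b = torch.empty(E, N, K // 2, dtype=torch.int8, device=device)  # two int4 per byte
    a_scales = torch.empty(1, dtype=torch.float32, device=device)  # one activation scale
    b_scales = torch.empty(
        E, K // 512, N * 4, dtype=torch.bfloat16, device=device
    )  # interleaved per-group weight scales
    return d, a, b, a_scales, b_scales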

sgl-kernel/tests/test_cutlass_w4a8_moe_mm.py (new file)
@@ -0,0 +1,260 @@
import pytest
import torch
from sgl_kernel import cutlass_w4a8_moe_mm


def pack_int4_values_to_int8(int4_values_interleaved: torch.Tensor) -> torch.Tensor:
    if int4_values_interleaved.shape[-1] % 2 != 0:
        raise ValueError(
            "the last dim size of int4_values_interleaved tensor must be even."
        )

    input_tensor_int8 = int4_values_interleaved.to(torch.int8)

    low_nibbles = input_tensor_int8[..., 0::2]
    high_nibbles = input_tensor_int8[..., 1::2]

    packed_tensor = (high_nibbles << 4) | (low_nibbles & 0x0F)

    return packed_tensor.to(torch.int8)


def pack_interleave(num_experts, ref_weight, ref_scale):
    n, k = ref_weight.shape[1], ref_weight.shape[2]

    weight = pack_int4_values_to_int8(ref_weight.cpu()).cuda()
    w_q = weight.view((num_experts, n, k // 2)).view(torch.int8)
    w_q = w_q.contiguous()

    scale_interleaved = ref_scale.reshape(
        ref_scale.shape[0], ref_scale.shape[1], (ref_scale.shape[2] // 4), 4
    )  # [E, N, K//512, 4]
    scale_interleaved = scale_interleaved.permute(0, 2, 1, 3)  # [E, K//512, N, 4]
    scale_interleaved = scale_interleaved.reshape(
        ref_scale.shape[0], ref_scale.shape[2] // 4, ref_scale.shape[1] * 4
    )  # [E, K//512, N*4]
    w_scale = scale_interleaved.contiguous()

    return w_q, w_scale
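

# Editorial sanity-check sketch, not part of the original test file: unpack the
# int8-packed int4 values produced by pack_int4_values_to_int8 above and verify
# the round trip. `_unpack_int8_to_int4_pairs` is a hypothetical helper that
# simply reverses the low-nibble-first packing used above.
def _unpack_int8_to_int4_pairs(packed: torch.Tensor) -> torch.Tensor:
    byte = packed.to(torch.int16) & 0xFF  # view each byte as an unsigned value
    low = byte & 0xF
    high = (byte >> 4) & 0xF
    # map unsigned nibbles in [8, 15] back to the signed int4 range [-8, -1]
    low = torch.where(low >= 8, low - 16, low)
    high = torch.where(high >= 8, high - 16, high)
    return torch.stack((low, high), dim=-1).flatten(-2).to(torch.int8)


def test_pack_int4_round_trip():
    ref = torch.randint(-8, 8, (2, 4, 8), dtype=torch.int8)
    packed = pack_int4_values_to_int8(ref)
    assert torch.equal(_unpack_int8_to_int4_pairs(packed), ref)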


@pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16])
def test_int4_fp8_grouped_gemm_single_expert(batch_size):
    # Test parameters
    num_experts = 1
    m = batch_size  # batch size
    k = 512  # input dimension
    n = 1024  # output dimension
    torch.manual_seed(0)
    dtype = torch.bfloat16
    device = "cuda"
    debug = False

    print(f"\nTesting with batch_size={batch_size}")

    # Create input tensors (all ones in debug mode, random otherwise)
    if debug:
        a = torch.ones(m, k, dtype=torch.bfloat16, device=device)
        ref_w = torch.ones(num_experts, n, k, dtype=torch.int8, device=device)
        a_scale = torch.ones(1, dtype=torch.float, device=device)
        ref_w_scale = torch.ones(num_experts, n, k // 128, dtype=dtype, device=device)
    else:
        a = torch.randn(m, k, dtype=dtype, device=device)
        ref_w = torch.randint(
            -8, 8, (num_experts, n, k), dtype=torch.int8, device=device
        )
        affine_coeff = 0.005
        a_scale = torch.randn(1, dtype=torch.float32).cuda() * 0.02
        ref_w_scale = (
            torch.randn(num_experts, n, k // 128, dtype=dtype, device=device)
            * affine_coeff
        )

    w, w_scale = pack_interleave(num_experts, ref_w, ref_w_scale)

    # Create expert offsets and problem sizes
    expert_offsets = torch.tensor([0, m], dtype=torch.int32, device=device)
    problem_sizes = torch.tensor([[n, m, k]], dtype=torch.int32, device=device)

    a_strides = torch.full((num_experts, 3), k, device=device, dtype=torch.int64)
    c_strides = torch.full((num_experts, 3), n, device=device, dtype=torch.int64)
    b_strides = a_strides
    s_strides = c_strides

    # Quantize input
    a_q = torch.clamp((a / a_scale), -448.0, 448.0).to(torch.float8_e4m3fn).to(device)

    # Create output tensor
    c = torch.empty((m, n), dtype=torch.float16, device=device)
    cutlass_w4a8_moe_mm(
        c,
        a_q,
        w,
        a_scale,
        w_scale,
        expert_offsets[:-1],
        problem_sizes,
        a_strides,
        b_strides,
        c_strides,
        s_strides,
        128,
        8,
    )
    c = c.to(dtype)

    # Reference implementation
    experts_selection_result = torch.full((m,), 0, device=device)
    c_ref = ref_grouped_gemm(
        c, a, a_scale, ref_w, ref_w_scale, num_experts, experts_selection_result
    )

    # Compare results
    try:
        torch.testing.assert_close(c, c_ref, rtol=1e-2, atol=0.1)
    except AssertionError as e:
        # torch.set_printoptions(threshold=10_000)
        print(" FAILURE: tensors are NOT close.")
        print(f" Ref tensor: {c_ref.flatten()}")
        print(f" Cutlass tensor: {c.flatten()}")
        print(
            f" Max absolute difference: {torch.max(torch.abs(c.to(c_ref.dtype) - c_ref))}"
        )
        print(
            f" Mean absolute difference: {torch.mean(torch.abs(c.to(c_ref.dtype) - c_ref))}"
        )
        print(f" AssertionError: {e}")
        raise


@pytest.mark.parametrize("batch_size", [2, 4, 8, 16])
@pytest.mark.parametrize("k", [512, 1024])
@pytest.mark.parametrize("n", [1024, 2048])
@pytest.mark.parametrize("num_experts", [2, 4, 6, 8])
def test_int4_fp8_grouped_gemm_multi_experts(batch_size, k, n, num_experts):
    torch.manual_seed(0)
    dtype = torch.bfloat16
    device = "cuda"
    debug = False

    print(
        f"\nTesting with batch_size={batch_size}, k={k}, n={n}, num_experts={num_experts}"
    )

    if debug:
        a = torch.ones(batch_size, k, dtype=torch.bfloat16, device=device)
        ref_w = torch.ones(num_experts, n, k, dtype=torch.int8, device=device)
        a_scale = torch.ones(1, dtype=torch.float, device=device)
        ref_w_scale = torch.ones(num_experts, n, k // 128, dtype=dtype, device=device)
    else:
        a = torch.randn(batch_size, k, dtype=dtype, device=device)
        ref_w = torch.randint(
            -8, 8, (num_experts, n, k), dtype=torch.int8, device=device
        )
        affine_coeff = 0.005
        a_scale = torch.randn(1, dtype=torch.float32).cuda() * 0.02
        ref_w_scale = (
            torch.randn(num_experts, n, k // 128, dtype=dtype, device=device)
            * affine_coeff
        )

    w, w_scale = pack_interleave(num_experts, ref_w, ref_w_scale)

    # random select experts
    experts_selection_result = torch.randint(
        0, num_experts, (batch_size,), device=device
    )
    permutation = torch.argsort(experts_selection_result)
    expert_token_counts = torch.bincount(
        experts_selection_result, minlength=num_experts
    )

    # Create problem sizes and offsets for active experts
    problem_sizes = []
    for i in range(num_experts):
        problem_sizes.append([n, expert_token_counts[i].item(), k])
    problem_sizes = torch.tensor(problem_sizes, dtype=torch.int32, device=device)

    expert_offsets = []
    offset = 0
    for i in range(num_experts):
        expert_offsets.append(offset)
        offset += problem_sizes[i][1].item()
    expert_offsets = torch.tensor(expert_offsets, dtype=torch.int32, device=device)

    # Permute input and quantize
    a_perm = a[permutation]
    a_q_perm = (
        torch.clamp((a_perm / a_scale), -448.0, 448.0)
        .to(torch.float8_e4m3fn)
        .to(device)
    )

    # Create stride tensors
    a_strides = torch.full((num_experts, 3), k, device=device, dtype=torch.int64)
    c_strides = torch.full((num_experts, 3), n, device=device, dtype=torch.int64)
    b_strides = a_strides
    s_strides = c_strides

    c_perm = torch.empty((batch_size, n), dtype=torch.float16, device=device)
    cutlass_w4a8_moe_mm(
        c_perm,
        a_q_perm,
        w,
        a_scale,
        w_scale,
        expert_offsets,
        problem_sizes,
        a_strides,
        b_strides,
        c_strides,
        s_strides,
        128,
        8,
    )

    # Un-permute the result
    c = torch.empty_like(c_perm)
    c[permutation] = c_perm
    c = c.to(dtype)

    c_ref = ref_grouped_gemm(
        c, a, a_scale, ref_w, ref_w_scale, num_experts, experts_selection_result
    )

    # Compare results
    try:
        torch.testing.assert_close(c, c_ref, rtol=1e-2, atol=0.1)
    except AssertionError as e:
        print(" FAILURE: tensors are NOT close.")
        print(
            f" Max absolute difference: {torch.max(torch.abs(c.to(c_ref.dtype) - c_ref))}"
        )
        print(
            f" Mean absolute difference: {torch.mean(torch.abs(c.to(c_ref.dtype) - c_ref))}"
        )
        print(f" AssertionError: {e}")
        raise


def ref_grouped_gemm(c, a, a_scale, w, w_scale, num_experts, experts_selection_result):
    dtype = torch.bfloat16
    c_ref = torch.zeros_like(c)
    a_q = torch.clamp((a / a_scale), -448.0, 448.0).to(torch.float8_e4m3fn)
    for i in range(num_experts):
        token_idx = torch.where(experts_selection_result == i)[0]
        if len(token_idx) == 0:
            continue
        a = a_q[token_idx]

        ref_w_scale_repeat = w_scale[i].repeat_interleave(128, dim=1).to(float)
        ref_w = (w[i].to(float) * ref_w_scale_repeat).to(dtype)
        c = torch.matmul(a.to(dtype), ref_w.t().to(dtype)) * a_scale
        c = c.to(dtype)
        c_ref[token_idx] = c.to(dtype)

    return c_ref


if __name__ == "__main__":
    pytest.main([__file__])