[1/n]: add cutlass W4A8 moe kernel for hopper architecture (#7772)
Signed-off-by: yangsijia.614 <yangsijia.614@bytedance.com>
Co-authored-by: yicwang <yichen.wang@bytedance.com>
@@ -0,0 +1,278 @@
/*
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include "cute/arch/cluster_sm90.hpp"
#include "cute/tensor.hpp"
#include "cutlass/gemm/collective/builders/sm90_common.inl"
#include "cutlass/gemm/collective/collective_builder_decl.hpp"
#include "cutlass/gemm/collective/collective_mma_decl.hpp"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/pipeline/sm90_pipeline.hpp"

// SM90 Collective Builders should be used only starting CUDA 12.0
#if (__CUDACC_VER_MAJOR__ >= 12)
#define CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
#endif

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass::gemm::collective {

/////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA_TMA_WS_RS
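// This partial specialization builds the register-source (RS) GMMA mainloop used by the W4A8
// MoE path: the narrow operand is staged through shared memory via TMA and then converted
// (and optionally scaled / zero-adjusted) in registers before each MMA. The enable_if below
// admits the TMA warp-specialized schedules, including the Ptr-Array (grouped) variants.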
template <
    class ElementA_,
    class GmemLayoutATag_,
    int AlignmentA,
    class ElementB_,
    class GmemLayoutBTag_,
    int AlignmentB,
    class ElementAccumulator,
    class TileShape_MNK,
    class ClusterShape_MNK,
    class StageCountType,
    class KernelScheduleType>
struct CollectiveBuilderMixedInput<
    arch::Sm90,
    arch::OpClassTensorOp,
    ElementA_,
    GmemLayoutATag_,
    AlignmentA,
    ElementB_,
    GmemLayoutBTag_,
    AlignmentB,
    ElementAccumulator,
    TileShape_MNK,
    ClusterShape_MNK,
    StageCountType,
    KernelScheduleType,
    cute::enable_if_t<
        (cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecialized> ||
         cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedPingpong> ||
         cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperative> ||
         cute::is_same_v<KernelScheduleType, KernelPtrArrayTmaWarpSpecializedCooperative> ||
         cute::is_same_v<KernelScheduleType, KernelPtrArrayTmaWarpSpecializedPingpong>) &&
        (detail::is_use_rmem_A<ElementA_, GmemLayoutATag_, ElementB_, GmemLayoutBTag_>() ||
         // ConvertAndScale and ConvertAndScaleWithZero
         cute::is_tuple<ElementA_>::value || cute::is_tuple<ElementB_>::value ||
         // DirectConvert
         sizeof_bits<ElementA_>::value != sizeof_bits<ElementB_>::value)>> {
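  // Mixed-input operands are encoded in the element type itself:
  //   DirectConvert:           ElementX_ = Element                          (widths simply differ)
  //   ConvertAndScale:         ElementX_ = cute::tuple<Element, Scale>
  //   ConvertAndScaleWithZero: ElementX_ = cute::tuple<Element, Scale, Zero>
  // deduce_mixed_width_dtype_t<I, ElementX_> below unpacks position I (0 = value, 1 = scale, 2 = zero).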
 private:
  using ScaleA = detail::deduce_mixed_width_dtype_t<1, ElementA_>;
  using ScaleB = detail::deduce_mixed_width_dtype_t<1, ElementB_>;
  using ZeroA = detail::deduce_mixed_width_dtype_t<2, ElementA_>;
  using ZeroB = detail::deduce_mixed_width_dtype_t<2, ElementB_>;
  static constexpr bool NeitherIsTuple = !cute::is_tuple<ElementA_>::value && !cute::is_tuple<ElementB_>::value;
  // Determine if mixed input types.
  static constexpr bool IsMixedInput = cute::sizeof_bits_v<detail::deduce_mixed_width_dtype_t<0, ElementA_>> !=
                                       cute::sizeof_bits_v<detail::deduce_mixed_width_dtype_t<0, ElementB_>>;
  static constexpr bool IsArrayOfPointersGemm = cute::is_any_of_v<
      KernelScheduleType,
      KernelPtrArrayTmaWarpSpecializedCooperative,
      KernelPtrArrayTmaWarpSpecializedPingpong>;
  static_assert(IsMixedInput || !IsArrayOfPointersGemm, "Only mixed input grouped RS GEMM is supported.");
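  // The Ptr-Array (grouped) schedules run one GEMM per group from device-side pointer arrays,
  // which is how the MoE kernel in this PR batches experts; the assertion above restricts the
  // grouped path to mixed-input RS configurations.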

 public:
  using ElementA = detail::deduce_mixed_width_dtype_t<0, ElementA_>;
  using ElementB = detail::deduce_mixed_width_dtype_t<0, ElementB_>;

  static_assert(
      !IsMixedInput || (cute::is_tuple<ElementA_>::value ^ cute::is_tuple<ElementB_>::value ||
                        (NeitherIsTuple && (sizeof_bits<ElementA>::value != sizeof_bits<ElementB>::value))),
      "Either A OR B must be a tuple or the widths of A and B must be different.");

  static constexpr bool IsANarrow = sizeof_bits<ElementA>::value < sizeof_bits<ElementB>::value;

  template <class T>
  static auto get_stride(T const& t) {
    if constexpr (not cute::is_layout<cute::remove_pointer_t<T>>::value) {
      return t;
    } else {
      if constexpr (cute::is_pointer_v<T>) {
        return &cute::stride(*t);
      } else {
        return cute::stride(t);
      }
    }
  }
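  // Note: for the grouped path a cute layout may arrive as a pointer to per-group layouts, in
  // which case get_stride returns a pointer to the stride rather than the stride itself.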

  using GmemLayoutATag = decltype(get_stride(GmemLayoutATag_{}));
  using GmemLayoutBTag = decltype(get_stride(GmemLayoutBTag_{}));

  using ElementPairA =
      cute::conditional_t<IsMixedInput && IsANarrow && NeitherIsTuple, cute::tuple<ElementA>, ElementA_>;
  using ElementPairB =
      cute::conditional_t<IsMixedInput && !IsANarrow && NeitherIsTuple, cute::tuple<ElementB>, ElementB_>;

  static constexpr bool IsATransformed = cute::is_tuple<ElementPairA>::value;
  using ElementScale = cute::conditional_t<IsATransformed, ScaleA, ScaleB>;
  using ElementZero = cute::conditional_t<IsATransformed, ZeroA, ZeroB>;

  static_assert(is_static<TileShape_MNK>::value);
  static_assert(is_static<ClusterShape_MNK>::value);
  static_assert(
      detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>(),
      "Should meet TMA alignment requirement\n");
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
  static_assert(cutlass::detail::dependent_false<ElementA>, "Unsupported Toolkit for SM90 Collective Builder\n");
#endif
  static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_rs_tag_to_major_A<GmemLayoutATag>();
  static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_rs_tag_to_major_B<GmemLayoutBTag>();
  // If A is scaled, then we don't need to swap. Otherwise, we must ensure B goes to rmem and we must swap the
  // operands.
  static constexpr bool SwapAB =
      IsMixedInput ? !IsATransformed : detail::is_swapAB<ElementA, GmemLayoutATag, ElementB, GmemLayoutBTag>();
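  // For example, with W4A8 the 4-bit weights are typically carried as the tuple (scaled) B
  // operand; IsATransformed is then false, so SwapAB is true and the weights take the register
  // path, where they are upconverted before the MMA.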
  static constexpr bool IsWarpSpecializedTransposeB =
      detail::is_warpspecialized_transpose_B<ElementA, GmemLayoutATag, ElementB, GmemLayoutBTag, KernelScheduleType>();
  static_assert(!IsMixedInput || !IsWarpSpecializedTransposeB, "Mixed input GEMM does not support WS transpose B.");

  // When we relax the above assertion, we must handle setting the tile mma GmmaMajorB correctly.
  static constexpr cute::GMMA::Major TiledMmaGmmaMajorB = SwapAB ? GmmaMajorA : GmmaMajorB;

  // For fp32 types, map to tf32 MMA value type.
  using ElementAMma = cute::conditional_t<cute::is_same_v<ElementA, float>, tfloat32_t, ElementA>;
  using ElementBMma = cute::conditional_t<cute::is_same_v<ElementB, float>, tfloat32_t, ElementB>;

  // Handle mixed dtypes and MMA.
  using RealElementA = cute::conditional_t<SwapAB, ElementBMma, ElementAMma>;
  using RealElementB = cute::conditional_t<SwapAB, ElementAMma, ElementBMma>;
  using RealElementAMma = cute::conditional_t<IsMixedInput, RealElementB, RealElementA>;
  // Always the same for element B.
  using RealElementBMma = RealElementB;

  static_assert(
      !IsMixedInput || TiledMmaGmmaMajorB == GMMA::Major::K || sizeof_bits<RealElementB>::value == 16,
      "Mixed input GEMM does not support MN major layout except for 16bit");

  using AtomLayoutMNK = cute::conditional_t<
      cute::is_any_of_v<
          KernelScheduleType,
          KernelTmaWarpSpecializedCooperative,
          KernelPtrArrayTmaWarpSpecializedCooperative>,
      Layout<Shape<_2, _1, _1>>,
      Layout<Shape<_1, _1, _1>>>;

  using TiledMma = decltype(cute::make_tiled_mma(
      cute::GMMA::rs_op_selector<
          RealElementAMma,
          RealElementBMma,
          ElementAccumulator,
          TileShape_MNK,
          GMMA::Major::K,
          GMMA::Major::K>(),
      AtomLayoutMNK{}));
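  // rs_op_selector picks a GMMA atom whose A operand is sourced from registers (required for
  // the in-register conversion above); the cooperative schedules tile two atoms along M via
  // AtomLayoutMNK, matching their two consumer warp groups.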

  using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
  using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{})));

  using SmemLayoutAtomA = decltype(detail::rs_smem_selector<
                                   GmmaMajorA,
                                   ElementAMma,
                                   decltype(cute::get<0>(TileShape_MNK{})),
                                   decltype(cute::get<2>(TileShape_MNK{})),
                                   IsWarpSpecializedTransposeB>());
  using SmemLayoutAtomB = decltype(detail::rs_smem_selector<
                                   GmmaMajorB,
                                   ElementBMma,
                                   decltype(cute::get<1>(TileShape_MNK{})),
                                   decltype(cute::get<2>(TileShape_MNK{})),
                                   IsWarpSpecializedTransposeB>());

  static constexpr size_t SmemAlignmentA = cutlass::detail::alignment_for_swizzle(SmemLayoutAtomA{});
  static constexpr size_t SmemAlignmentB = cutlass::detail::alignment_for_swizzle(SmemLayoutAtomB{});
  static constexpr int SmemAlignment = static_cast<int>(cute::max(SmemAlignmentA, SmemAlignmentB));

  // Handle mixed dtype array GEMM's size of tensor map storage.
  static constexpr size_t TensorMapStorage = sizeof(cute::TmaDescriptor) * size_t(IsMixedInput) * 4;
  static constexpr int KernelSmemCarveout = static_cast<int>(TensorMapStorage);
  static constexpr int Sm90ReducedSmemCapacityBytes = detail::sm90_smem_capacity_bytes - KernelSmemCarveout;
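  // For the grouped mixed-input kernel, four TMA descriptors (assumed here to cover A, B,
  // scale, and zero) are carved out of shared memory so they can be updated per group. With
  // sizeof(cute::TmaDescriptor) == 128 bytes, this reserves 4 * 128 = 512 bytes.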

  static constexpr int PipelineStages =
      IsMixedInput ? (IsArrayOfPointersGemm ? detail::compute_stage_count_or_override_single_affine_transformed_input<
                                                  Sm90ReducedSmemCapacityBytes,
                                                  RealElementA,
                                                  RealElementB,
                                                  ElementScale,
                                                  ElementZero,
                                                  TileShape_MNK,
                                                  StageCountType::bytes,
                                                  SmemAlignment>(StageCountType{})
                                            : detail::compute_stage_count_or_override_single_affine_transformed_input<
                                                  detail::sm90_smem_capacity_bytes,
                                                  RealElementA,
                                                  RealElementB,
                                                  ElementScale,
                                                  ElementZero,
                                                  TileShape_MNK,
                                                  StageCountType::bytes,
                                                  SmemAlignment>(StageCountType{}))
                   : detail::compute_stage_count_or_override<
                         detail::sm90_smem_capacity_bytes,
                         ElementAMma,
                         ElementBMma,
                         TileShape_MNK,
                         StageCountType::bytes,
                         SmemAlignment>(StageCountType{});

  using DispatchPolicy = cute::conditional_t<
      IsMixedInput,
      cute::conditional_t<
          IsArrayOfPointersGemm,
          MainloopSm90ArrayTmaGmmaWarpSpecializedMixedInput<PipelineStages, ClusterShape_MNK, KernelScheduleType>,
          MainloopSm90TmaGmmaRmemAWarpSpecializedMixedInput<PipelineStages, ClusterShape_MNK, KernelScheduleType>>,
      MainloopSm90TmaGmmaRmemAWarpSpecialized<PipelineStages, ClusterShape_MNK, KernelScheduleType>>;

  using SmemCopyAtomA = cute::conditional_t<SwapAB, void, Copy_Atom<cute::AutoVectorizingCopy, ElementA>>;
  using SmemCopyAtomB = cute::conditional_t<SwapAB, Copy_Atom<cute::AutoVectorizingCopy, ElementB>, void>;

  // We pack the scale data with the operand that will be optionally scaled and converted before MMA.
  using StrideA = cute::conditional_t<
      cute::is_layout<cute::remove_pointer_t<GmemLayoutATag_>>::value,
      GmemLayoutATag_,
      TagToStrideA_t<GmemLayoutATag>>;
  using StrideB = cute::conditional_t<
      cute::is_layout<cute::remove_pointer_t<GmemLayoutBTag_>>::value,
      GmemLayoutBTag_,
      TagToStrideB_t<GmemLayoutBTag>>;

  using CollectiveOp = CollectiveMmaArrayMixedInput<
      DispatchPolicy,
      TileShape_MNK,
      ElementPairA,
      StrideA,
      ElementPairB,
      StrideB,
      TiledMma,
      GmemTiledCopyA,
      SmemLayoutAtomA,
      SmemCopyAtomA,
      cute::identity,
      GmemTiledCopyB,
      SmemLayoutAtomB,
      SmemCopyAtomB,
      cute::identity>;

  static_assert(
      SmemAlignment == static_cast<int>(cute::max(CollectiveOp::SmemAlignmentA, CollectiveOp::SmemAlignmentB)));
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace cutlass::gemm::collective

/////////////////////////////////////////////////////////////////////////////////////////////////
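For reference, a minimal sketch of how the builder above might be instantiated for a W4A8-style
grouped GEMM. The element and scale types, alignments, tile/cluster shapes, and the
pointer-to-layout tag convention (borrowed from CUTLASS's grouped-GEMM examples) are
illustrative assumptions, not the exact configuration used by this PR's MoE kernel:

// Hypothetical W4A8 configuration: fp8 (e4m3) activations as A; int4 weights packed in a
// cute::tuple with their scales as B, which selects the ConvertAndScale mode; fp32 accumulation.
using ElementAct    = cutlass::float_e4m3_t;  // 8-bit activations
using ElementWeight = cutlass::int4b_t;       // 4-bit weights
using ElementScale  = cutlass::float_e4m3_t;  // assumed scale type

using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilderMixedInput<
    cutlass::arch::Sm90,
    cutlass::arch::OpClassTensorOp,
    ElementAct,
    cutlass::layout::RowMajor*,                            // pointer tag: per-group strides
    128 / cutlass::sizeof_bits<ElementAct>::value,         // 16-byte TMA alignment for A
    cute::tuple<ElementWeight, ElementScale>,              // tuple triggers ConvertAndScale
    cutlass::layout::ColumnMajor*,
    128 / cutlass::sizeof_bits<ElementWeight>::value,      // 16-byte TMA alignment for B
    float,                                                 // ElementAccumulator
    cute::Shape<cute::_128, cute::_128, cute::_128>,       // TileShape_MNK (assumed)
    cute::Shape<cute::_1, cute::_1, cute::_1>,             // ClusterShape_MNK (assumed)
    cutlass::gemm::collective::StageCountAutoCarveout<0>,  // supplies StageCountType::bytes
    cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative>::CollectiveOp;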
@@ -0,0 +1,52 @@
/*
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

/////////////////////////////////////////////////////////////////////////////////////////////////
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass_extensions/gemm/collective/collective_mma_array_mixed_input.hpp"

namespace cutlass::gemm::collective {

/////////////////////////////////////////////////////////////////////////////////////////////////

template <
    class ArchTag,
    class OpClass,
    class ElementA,
    class GmemLayoutA,
    int AlignmentA,
    class ElementB,
    class GmemLayoutB,
    int AlignmentB,
    class ElementAccumulator,
    class TileShape_MNK,
    class ClusterShape_MNK,
    class StageCountType,
    class KernelScheduleType,
    class Enable = void>
struct CollectiveBuilderMixedInput {
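  // Fallback primary template: instantiated only when no specialization matches. Using
  // sizeof(ElementA) keeps the assertion dependent on a template parameter, so it fires for
  // the failing instantiation rather than at declaration time.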
  static_assert(sizeof(ElementA) == 0, "Could not build a collective for given parameters.");
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace cutlass::gemm::collective

/////////////////////////////////////////////////////////////////////////////////////////////////

#include "cutlass_extensions/gemm/collective/builders/sm90_gmma_builder_mixed_input.inl"
/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -0,0 +1,53 @@
/*
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include "cutlass/detail/dependent_false.hpp"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass::gemm::collective {

/////////////////////////////////////////////////////////////////////////////////////////////////

template <
    class DispatchPolicy,
    class TileShape,
    class ElementA,
    class StrideA,
    class ElementB,
    class StrideB,
    class TiledMma,
    class GmemTiledCopyA,
    class SmemLayoutAtomA,
    class SmemCopyAtomA,
    class TransformA,
    class GmemTiledCopyB,
    class SmemLayoutAtomB,
    class SmemCopyAtomB,
    class TransformB>
struct CollectiveMmaArrayMixedInput {
  static_assert(cutlass::detail::dependent_false<ElementA>, "Could not find a mainloop specialization.");
};
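// The template parameters above mirror, one for one, the arguments the SM90 builder passes
// when forming its CollectiveOp: dispatch policy, tile shape, the (possibly tuple-valued)
// element/stride pair for each operand, the tiled MMA, and the gmem/smem copy atoms plus
// transforms for A and B.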

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace cutlass::gemm::collective

/////////////////////////////////////////////////////////////////////////////////////////////////

#include "cutlass_extensions/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input_.hpp"
/////////////////////////////////////////////////////////////////////////////////////////////////
File diff suppressed because it is too large