adapt to sglang v0.5.2rc1 on dcu

This commit is contained in:
maxiao
2025-09-04 15:56:33 +08:00
commit 909abb58f5
2320 changed files with 489411 additions and 0 deletions

View File

@@ -0,0 +1,278 @@
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cute/arch/cluster_sm90.hpp"
#include "cute/tensor.hpp"
#include "cutlass/gemm/collective/builders/sm90_common.inl"
#include "cutlass/gemm/collective/collective_builder_decl.hpp"
#include "cutlass/gemm/collective/collective_mma_decl.hpp"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/pipeline/sm90_pipeline.hpp"
// SM90 Collective Builders should be used only starting CUDA 12.0
#if (__CUDACC_VER_MAJOR__ >= 12)
#define CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
#endif
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::collective {
/////////////////////////////////////////////////////////////////////////////////////////////////
// GMMA_TMA_WS_RS
template <
class ElementA_,
class GmemLayoutATag_,
int AlignmentA,
class ElementB_,
class GmemLayoutBTag_,
int AlignmentB,
class ElementAccumulator,
class TileShape_MNK,
class ClusterShape_MNK,
class StageCountType,
class KernelScheduleType>
struct CollectiveBuilderMixedInput<
arch::Sm90,
arch::OpClassTensorOp,
ElementA_,
GmemLayoutATag_,
AlignmentA,
ElementB_,
GmemLayoutBTag_,
AlignmentB,
ElementAccumulator,
TileShape_MNK,
ClusterShape_MNK,
StageCountType,
KernelScheduleType,
cute::enable_if_t<
(cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecialized> ||
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedPingpong> ||
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperative> ||
cute::is_same_v<KernelScheduleType, KernelPtrArrayTmaWarpSpecializedCooperative> ||
cute::is_same_v<KernelScheduleType, KernelPtrArrayTmaWarpSpecializedPingpong>) &&
(detail::is_use_rmem_A<ElementA_, GmemLayoutATag_, ElementB_, GmemLayoutBTag_>() ||
// ConvertAndScale and ConvertAndScaleWithZero
cute::is_tuple<ElementA_>::value || cute::is_tuple<ElementB_>::value ||
// DirectConvert
sizeof_bits<ElementA_>::value != sizeof_bits<ElementB_>::value)>> {
private:
using ScaleA = detail::deduce_mixed_width_dtype_t<1, ElementA_>;
using ScaleB = detail::deduce_mixed_width_dtype_t<1, ElementB_>;
using ZeroA = detail::deduce_mixed_width_dtype_t<2, ElementA_>;
using ZeroB = detail::deduce_mixed_width_dtype_t<2, ElementB_>;
static constexpr bool NeitherIsTuple = !cute::is_tuple<ElementA_>::value && !cute::is_tuple<ElementB_>::value;
// Determine if mixed input types.
static constexpr bool IsMixedInput = cute::sizeof_bits_v<detail::deduce_mixed_width_dtype_t<0, ElementA_>> !=
cute::sizeof_bits_v<detail::deduce_mixed_width_dtype_t<0, ElementB_>>;
static constexpr bool IsArrayOfPointersGemm = cute::is_any_of_v<
KernelScheduleType,
KernelPtrArrayTmaWarpSpecializedCooperative,
KernelPtrArrayTmaWarpSpecializedPingpong>;
static_assert(IsMixedInput || !IsArrayOfPointersGemm, "Only mixed input grouped RS GEMM is supported.");
public:
using ElementA = detail::deduce_mixed_width_dtype_t<0, ElementA_>;
using ElementB = detail::deduce_mixed_width_dtype_t<0, ElementB_>;
static_assert(
!IsMixedInput || (cute::is_tuple<ElementA_>::value ^ cute::is_tuple<ElementB_>::value ||
(NeitherIsTuple && (sizeof_bits<ElementA>::value != sizeof_bits<ElementB>::value))),
"Either A OR B must be a tuple or the widths of A and B must be different.");
static constexpr bool IsANarrow = sizeof_bits<ElementA>::value < sizeof_bits<ElementB>::value;
template <class T>
static auto get_stride(T const& t) {
if constexpr (not cute::is_layout<cute::remove_pointer_t<T>>::value) {
return t;
} else {
if constexpr (cute::is_pointer_v<T>) {
return &cute::stride(*t);
} else {
return cute::stride(t);
}
}
}
using GmemLayoutATag = decltype(get_stride(GmemLayoutATag_{}));
using GmemLayoutBTag = decltype(get_stride(GmemLayoutBTag_{}));
using ElementPairA =
cute::conditional_t<IsMixedInput && IsANarrow && NeitherIsTuple, cute::tuple<ElementA>, ElementA_>;
using ElementPairB =
cute::conditional_t<IsMixedInput && !IsANarrow && NeitherIsTuple, cute::tuple<ElementB>, ElementB_>;
static constexpr bool IsATransformed = cute::is_tuple<ElementPairA>::value;
using ElementScale = cute::conditional_t<IsATransformed, ScaleA, ScaleB>;
using ElementZero = cute::conditional_t<IsATransformed, ZeroA, ZeroB>;
static_assert(is_static<TileShape_MNK>::value);
static_assert(is_static<ClusterShape_MNK>::value);
static_assert(
detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>(),
"Should meet TMA alignment requirement\n");
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
static_assert(cutlass::detail::dependent_false<ElementA>, "Unsupported Toolkit for SM90 Collective Builder\n");
#endif
static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_rs_tag_to_major_A<GmemLayoutATag>();
static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_rs_tag_to_major_B<GmemLayoutBTag>();
// If A is scaled, then we don't need to swap. Otherwise, we must ensure B goes to rmem and we must swap the
// operands.
static constexpr bool SwapAB =
IsMixedInput ? !IsATransformed : detail::is_swapAB<ElementA, GmemLayoutATag, ElementB, GmemLayoutBTag>();
static constexpr bool IsWarpSpecializedTransposeB =
detail::is_warpspecialized_transpose_B<ElementA, GmemLayoutATag, ElementB, GmemLayoutBTag, KernelScheduleType>();
static_assert(!IsMixedInput || !IsWarpSpecializedTransposeB, "Mixed input GEMM does not support WS transpose B.");
// When we relax the above assertion, we must handle setting the tile mma GmmaMajorB correctly.
static constexpr cute::GMMA::Major TiledMmaGmmaMajorB = SwapAB ? GmmaMajorA : GmmaMajorB;
// For fp32 types, map to tf32 MMA value type.
using ElementAMma = cute::conditional_t<cute::is_same_v<ElementA, float>, tfloat32_t, ElementA>;
using ElementBMma = cute::conditional_t<cute::is_same_v<ElementB, float>, tfloat32_t, ElementB>;
// Handle mixed dtypes and MMA.
using RealElementA = cute::conditional_t<SwapAB, ElementBMma, ElementAMma>;
using RealElementB = cute::conditional_t<SwapAB, ElementAMma, ElementBMma>;
using RealElementAMma = cute::conditional_t<IsMixedInput, RealElementB, RealElementA>;
// Always the same for element B.
using RealElementBMma = RealElementB;
static_assert(
!IsMixedInput || TiledMmaGmmaMajorB == GMMA::Major::K || sizeof_bits<RealElementB>::value == 16,
"Mixed input GEMM does not support MN major layout except for 16bit");
using AtomLayoutMNK = cute::conditional_t<
cute::is_any_of_v<
KernelScheduleType,
KernelTmaWarpSpecializedCooperative,
KernelPtrArrayTmaWarpSpecializedCooperative>,
Layout<Shape<_2, _1, _1>>,
Layout<Shape<_1, _1, _1>>>;
using TiledMma = decltype(cute::make_tiled_mma(
cute::GMMA::rs_op_selector<
RealElementAMma,
RealElementBMma,
ElementAccumulator,
TileShape_MNK,
GMMA::Major::K,
GMMA::Major::K>(),
AtomLayoutMNK{}));
using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{})));
using SmemLayoutAtomA = decltype(detail::rs_smem_selector<
GmmaMajorA,
ElementAMma,
decltype(cute::get<0>(TileShape_MNK{})),
decltype(cute::get<2>(TileShape_MNK{})),
IsWarpSpecializedTransposeB>());
using SmemLayoutAtomB = decltype(detail::rs_smem_selector<
GmmaMajorB,
ElementBMma,
decltype(cute::get<1>(TileShape_MNK{})),
decltype(cute::get<2>(TileShape_MNK{})),
IsWarpSpecializedTransposeB>());
static constexpr size_t SmemAlignmentA = cutlass::detail::alignment_for_swizzle(SmemLayoutAtomA{});
static constexpr size_t SmemAlignmentB = cutlass::detail::alignment_for_swizzle(SmemLayoutAtomB{});
static constexpr int SmemAlignment = static_cast<int>(cute::max(SmemAlignmentA, SmemAlignmentB));
// Handle mixed dtype array GEMM's size of tensor map storage.
static constexpr size_t TensorMapStorage = sizeof(cute::TmaDescriptor) * size_t(IsMixedInput) * 4;
static constexpr int KernelSmemCarveout = static_cast<int>(TensorMapStorage);
static constexpr int Sm90ReducedSmemCapacityBytes = detail::sm90_smem_capacity_bytes - KernelSmemCarveout;
static constexpr int PipelineStages =
IsMixedInput ? (IsArrayOfPointersGemm ? detail::compute_stage_count_or_override_single_affine_transformed_input<
Sm90ReducedSmemCapacityBytes,
RealElementA,
RealElementB,
ElementScale,
ElementZero,
TileShape_MNK,
StageCountType::bytes,
SmemAlignment>(StageCountType{})
: detail::compute_stage_count_or_override_single_affine_transformed_input<
detail::sm90_smem_capacity_bytes,
RealElementA,
RealElementB,
ElementScale,
ElementZero,
TileShape_MNK,
StageCountType::bytes,
SmemAlignment>(StageCountType{}))
: detail::compute_stage_count_or_override<
detail::sm90_smem_capacity_bytes,
ElementAMma,
ElementBMma,
TileShape_MNK,
StageCountType::bytes,
SmemAlignment>(StageCountType{});
using DispatchPolicy = cute::conditional_t<
IsMixedInput,
cute::conditional_t<
IsArrayOfPointersGemm,
MainloopSm90ArrayTmaGmmaWarpSpecializedMixedInput<PipelineStages, ClusterShape_MNK, KernelScheduleType>,
MainloopSm90TmaGmmaRmemAWarpSpecializedMixedInput<PipelineStages, ClusterShape_MNK, KernelScheduleType>>,
MainloopSm90TmaGmmaRmemAWarpSpecialized<PipelineStages, ClusterShape_MNK, KernelScheduleType>>;
using SmemCopyAtomA = cute::conditional_t<SwapAB, void, Copy_Atom<cute::AutoVectorizingCopy, ElementA>>;
using SmemCopyAtomB = cute::conditional_t<SwapAB, Copy_Atom<cute::AutoVectorizingCopy, ElementB>, void>;
// We pack the scale data with the operand that will be optionally scaled and converted before MMA.
using StrideA = cute::conditional_t<
cute::is_layout<cute::remove_pointer_t<GmemLayoutATag_>>::value,
GmemLayoutATag_,
TagToStrideA_t<GmemLayoutATag>>;
using StrideB = cute::conditional_t<
cute::is_layout<cute::remove_pointer_t<GmemLayoutBTag_>>::value,
GmemLayoutBTag_,
TagToStrideB_t<GmemLayoutBTag>>;
using CollectiveOp = CollectiveMmaArrayMixedInput<
DispatchPolicy,
TileShape_MNK,
ElementPairA,
StrideA,
ElementPairB,
StrideB,
TiledMma,
GmemTiledCopyA,
SmemLayoutAtomA,
SmemCopyAtomA,
cute::identity,
GmemTiledCopyB,
SmemLayoutAtomB,
SmemCopyAtomB,
cute::identity>;
static_assert(
SmemAlignment == static_cast<int>(cute::max(CollectiveOp::SmemAlignmentA, CollectiveOp::SmemAlignmentB)));
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm::collective
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -0,0 +1,52 @@
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
/////////////////////////////////////////////////////////////////////////////////////////////////
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass_extensions/gemm/collective/collective_mma_array_mixed_input.hpp"
namespace cutlass::gemm::collective {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
class ArchTag,
class OpClass,
class ElementA,
class GmemLayoutA,
int AlignmentA,
class ElementB,
class GmemLayoutB,
int AlignmentB,
class ElementAccumulator,
class TileShape_MNK,
class ClusterShape_MNK,
class StageCountType,
class KernelScheduleType,
class Enable = void>
struct CollectiveBuilderMixedInput {
static_assert(sizeof(ElementA) == 0, "Could not build a collective for given parameters.");
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm::collective
/////////////////////////////////////////////////////////////////////////////////////////////////
#include "cutlass_extensions/gemm/collective/builders/sm90_gmma_builder_mixed_input.inl"
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -0,0 +1,53 @@
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cutlass/detail/dependent_false.hpp"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::collective {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
class DispatchPolicy,
class TileShape,
class ElementA,
class StrideA,
class ElementB,
class StrideB,
class TiledMma,
class GmemTiledCopyA,
class SmemLayoutAtomA,
class SmemCopyAtomA,
class TransformA,
class GmemTiledCopyB,
class SmemLayoutAtomB,
class SmemCopyAtomB,
class TransformB>
struct CollectiveMmaArrayMixedInput {
static_assert(cutlass::detail::dependent_false<ElementA>, "Could not find a mainloop specialization.");
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm::collective
/////////////////////////////////////////////////////////////////////////////////////////////////
#include "cutlass_extensions/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input_.hpp"
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -0,0 +1,62 @@
// Adapted from
// https://github.com/vllm-project/vllm/blob/main/csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh
#pragma once
// clang-format will break include orders
// clang-format off
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/util/packed_stride.hpp"
// clang-format on
/**
* Helper function for checking CUTLASS errors
*/
#define CUTLASS_CHECK(status) \
{ \
cutlass::Status error = status; \
TORCH_CHECK(error == cutlass::Status::kSuccess, cutlassGetStatusString(error)); \
}
template <typename GemmKernel>
void cutlass_gemm_caller(
torch::Device device,
cute::Shape<int, int, int, int> prob_shape,
typename GemmKernel::MainloopArguments mainloop_args,
typename GemmKernel::EpilogueArguments epilogue_args,
typename GemmKernel::TileSchedulerArguments scheduler = {}) {
cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = c10::cuda::current_device();
hw_info.sm_count = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
typename GemmKernel::Arguments args{
cutlass::gemm::GemmUniversalMode::kGemm, prob_shape, mainloop_args, epilogue_args, hw_info, scheduler};
// Launch the CUTLASS GEMM kernel.
using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
GemmOp gemm_op;
CUTLASS_CHECK(gemm_op.can_implement(args));
size_t workspace_size = gemm_op.get_workspace_size(args);
auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(device);
auto workspace = torch::empty(workspace_size, workspace_options);
auto stream = at::cuda::getCurrentCUDAStream(device.index());
cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);
CUTLASS_CHECK(status);
}

View File

@@ -0,0 +1,38 @@
// Adapted from https://github.com/vllm-project/vllm/blob/main/csrc/cutlass_extensions/gemm/dispatch_policy.hpp
#pragma once
#include "cutlass/gemm/dispatch_policy.hpp"
namespace cutlass::gemm {
//////////////////////////////////////////////////////////////////////////////
// FP8 related policies (including Blocked Scaled Accumulation)
// `ScaleGranularityM` specifies scaling granularity along M, while zero-value
// `ScaleGranularityM` indicates that scaling granularity is
// `size<0>(TileShape_MNK{})` along M.
template <int ScaleGranularityM = 0>
struct KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum : KernelTmaWarpSpecializedCooperative {};
// n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, Warp
// specialized dynamic schedule For FP8 kernels with Block Scaling
template <
int Stages_,
class ClusterShape_ = Shape<_1, _1, _1>,
class KernelSchedule = KernelTmaWarpSpecialized,
int ScaleGranularityM = 0 // `ScaleGranularityM` specifies scaling granularity along M,
// while zero-value `ScaleGranularityM` indicates that scaling
// granularity is `size<0>(TileShape_MNK{})` along M.
>
struct MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8
: MainloopSm90TmaGmmaWarpSpecialized<Stages_, ClusterShape_, KernelSchedule> {
static_assert(
cute::
is_same_v<KernelSchedule, KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>>,
"KernelSchedule must be one of the warp specialized policies");
};
//////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm

View File

@@ -0,0 +1,197 @@
// Adapted from
// https://github.com/vllm-project/vllm/blob/main/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh
#pragma once
#include "cute/tensor.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_ref.h"
#include "cutlass_extensions/common.hpp"
#include "cutlass_extensions/gemm/cutlass_gemm_caller.cuh"
#include "cutlass_extensions/gemm/dispatch_policy.hpp"
using namespace cute;
template <
typename SchedulerType,
typename OutType,
int GroupSizeM_,
int GroupSizeN_,
int GroupSizeK_,
int TileSizeM_ = 128,
class ClusterShape = Shape<_1, _2, _1>>
struct cutlass_3x_gemm_fp8_blockwise {
using GroupSizeM = Int<GroupSizeM_>;
using GroupSizeN = Int<GroupSizeN_>;
using GroupSizeK = Int<GroupSizeK_>;
using TileSizeM = Int<TileSizeM_>;
static_assert(TileSizeM_ % GroupSizeM_ == 0, "TileSizeM must be a multiple of GroupSizeM");
using ElementAB = cutlass::float_e4m3_t;
// A matrix configuration
using ElementA = ElementAB;
using LayoutA = cutlass::layout::RowMajor;
static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
// B matrix configuration
using ElementB = ElementAB;
using LayoutB = cutlass::layout::ColumnMajor;
static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
// C/D matrix configuration
using ElementC = void;
using LayoutC = cutlass::layout::RowMajor;
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<OutType>::value;
using ElementD = OutType;
using LayoutD = cutlass::layout::RowMajor;
static constexpr int AlignmentD = AlignmentC;
using ScaleTileShape = Shape<_1, _128, _128>;
using ScaleConfig = decltype(cutlass::detail::sm90_trivial_blockwise_scale_config(ScaleTileShape{}));
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
// Multiply-accumulate blocking/pipelining details
using ElementAccumulator = float; // Element type for internal accumulation
using ElementCompute = float; // Element type for compute
using TileShape = Shape<TileSizeM, GroupSizeN, GroupSizeK>; // Threadblock-level tile size
using ArchTag = cutlass::arch::Sm90;
using OperatorClass = cutlass::arch::OpClassTensorOp;
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
using StoreEpilogueCompute = typename cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90AccFetch>;
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OperatorClass,
TileShape,
ClusterShape,
EpilogueTileType,
ElementAccumulator,
ElementCompute,
ElementC,
LayoutC,
AlignmentC,
ElementD,
LayoutD,
AlignmentD,
EpilogueSchedule,
StoreEpilogueCompute>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OperatorClass,
ElementA,
cute::tuple<LayoutA, LayoutSFA>,
AlignmentA,
ElementB,
cute::tuple<LayoutB, LayoutSFB>,
AlignmentB,
ElementAccumulator,
TileShape,
ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelSchedule>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int, int, int, int>, // Indicates ProblemShape
CollectiveMainloop,
CollectiveEpilogue,
SchedulerType>;
};
template <typename Gemm>
void cutlass_gemm_caller_blockwise(
torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
using GemmKernel = typename Gemm::GemmKernel;
using ElementAB = typename Gemm::ElementAB;
using ElementA = ElementAB;
using ElementB = ElementAB;
using ElementD = typename Gemm::ElementD;
using ElementBlockScale = float;
using ScaleTileShape = Shape<_1, _128, _128>;
using ScaleConfig = decltype(cutlass::detail::sm90_trivial_blockwise_scale_config(ScaleTileShape{}));
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
int m = a.size(0);
int k = a.size(1);
int n = b.size(1);
auto a_ptr = static_cast<ElementA*>(a.data_ptr());
auto b_ptr = static_cast<ElementB*>(b.data_ptr());
auto a_s_ptr = static_cast<ElementBlockScale*>(a_scales.data_ptr());
auto b_s_ptr = static_cast<ElementBlockScale*>(b_scales.data_ptr());
using StrideA = typename GemmKernel::StrideA;
using StrideB = typename GemmKernel::StrideB;
using StrideD = typename GemmKernel::StrideD;
using StrideC = typename GemmKernel::StrideC;
StrideA a_stride = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));
StrideB b_stride = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1));
StrideC c_stride = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, 1));
LayoutSFA layout_sfa = ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, 1));
LayoutSFB layout_sfb = ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1));
typename GemmKernel::MainloopArguments mainloop_args{
a_ptr, a_stride, b_ptr, b_stride, a_s_ptr, layout_sfa, b_s_ptr, layout_sfb};
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
typename GemmKernel::EpilogueArguments epilogue_args{{}, c_ptr, c_stride, c_ptr, c_stride};
typename GemmKernel::TileSchedulerArguments scheduler;
static constexpr bool UsesStreamKScheduler =
cute::is_same_v<typename GemmKernel::TileSchedulerTag, cutlass::gemm::StreamKScheduler>;
if constexpr (UsesStreamKScheduler) {
using DecompositionMode =
typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode;
using ReductionMode =
typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::ReductionMode;
scheduler.decomposition_mode = DecompositionMode::StreamK;
scheduler.reduction_mode = ReductionMode::Nondeterministic;
}
cutlass_gemm_caller<GemmKernel>(a.device(), {m, n, k, 1}, mainloop_args, epilogue_args, scheduler);
}
template <typename OutType>
void cutlass_gemm_blockwise_sm90_fp8_dispatch(
torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
auto k = a.size(1);
auto n = b.size(1);
if (k > 3 * n) {
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<cutlass::gemm::StreamKScheduler, OutType, 1, 128, 128>>(
out, a, b, a_scales, b_scales);
} else {
cutlass_gemm_caller_blockwise<
cutlass_3x_gemm_fp8_blockwise<cutlass::gemm::PersistentScheduler, OutType, 1, 128, 128>>(
out, a, b, a_scales, b_scales);
}
}

View File

@@ -0,0 +1,356 @@
/* Copyright 2025 SGLang Team. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Adapted from
// https://github.com/NVIDIA/TensorRT-LLM/blob/be1788106245496872d18e702978e59b6bfd50e0/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/device/gemm_universal_base_compat.h
#pragma once
#include <cutlass/cutlass.h>
#include <cutlass/device_kernel.h>
#include <cutlass/trace.h>
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace device {
/////////////////////////////////////////////////////////////////////////////////////////////////
/*
This is the device layer from CUTLASS 2.10 (SHA - cc85b64cf676c45f98a17e3a47c0aafcf817f088)
It is replicated here since we needed to duplicate kernel level APIs for mixed dtype GEMMs
and SmoothQuant. The newer device layer is not compatible with these older kernel level APIs.
Note: While CUTLASS 3.x supports stream-k, none of the kernels in the extensions folder support
that feature at the moment.
*/
template <typename GemmKernel_>
class GemmUniversalBaseCompat {
public:
using GemmKernel = GemmKernel_;
using ThreadblockShape = typename GemmKernel::Mma::Shape;
using ElementA = typename GemmKernel::ElementA;
using LayoutA = typename GemmKernel::LayoutA;
using TensorRefA = TensorRef<ElementA const, LayoutA>;
static ComplexTransform const kTransformA = GemmKernel::kTransformA;
using ElementB = typename GemmKernel::ElementB;
using LayoutB = typename GemmKernel::LayoutB;
using TensorRefB = TensorRef<ElementB const, LayoutB>;
static ComplexTransform const kTransformB = GemmKernel::kTransformB;
using ElementC = typename GemmKernel::ElementC;
using LayoutC = typename GemmKernel::LayoutC;
using TensorRefC = TensorRef<ElementC const, LayoutC>;
using TensorRefD = TensorRef<ElementC, LayoutC>;
using ElementAccumulator = typename GemmKernel::Mma::Policy::Operator::ElementC;
using EpilogueOutputOp = typename GemmKernel::EpilogueOutputOp;
using ThreadblockSwizzle = typename GemmKernel::ThreadblockSwizzle;
using Operator = typename GemmKernel::Operator;
/// Argument structure
using Arguments = typename GemmKernel::Arguments;
protected:
/// Kernel parameters object
typename GemmKernel::Params params_;
protected:
/// Private helper to obtain the grid dimensions with fix-up for split-K
static void get_grid_shape_(gemm::GemmCoord& grid_tiled_shape, int& gemm_k_size, Arguments const& args) {
// Determine grid shape
ThreadblockSwizzle threadblock_swizzle;
grid_tiled_shape = threadblock_swizzle.get_tiled_shape(
args.problem_size, {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.batch_count);
gemm_k_size = args.problem_size.k();
if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) {
int const kAlignK =
const_max(const_max(128 / sizeof_bits<ElementA>::value, 128 / sizeof_bits<ElementB>::value), 1);
gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK);
if (gemm_k_size) {
grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size);
}
}
}
public:
/// Constructs the GEMM.
GemmUniversalBaseCompat() {}
/// Determines whether the GEMM can execute the given problem.
static Status can_implement(Arguments const& args) {
// Determine grid shape
cutlass::gemm::GemmCoord grid_tiled_shape;
int gemm_k_size = 0;
get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
ThreadblockSwizzle threadblock_swizzle;
dim3 grid = threadblock_swizzle.get_grid_shape(grid_tiled_shape);
uint32_t const kGridYZMax = ((1 << (sizeof(uint16_t) * 8)) - 1);
if (!(grid.y <= kGridYZMax && grid.z <= kGridYZMax)) {
return Status::kErrorInvalidProblem;
}
return GemmKernel::can_implement(args);
}
/// Gets the workspace size
static size_t get_workspace_size(Arguments const& args) {
CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_workspace_size()");
size_t workspace_bytes = 0;
// Determine grid shape
cutlass::gemm::GemmCoord grid_tiled_shape;
int gemm_k_size = 0;
get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
if (args.mode == GemmUniversalMode::kGemmSplitKParallel) {
// Split-K parallel always requires a temporary workspace
workspace_bytes = sizeof(ElementC) * size_t(args.batch_stride_D) * size_t(grid_tiled_shape.k());
} else if (args.mode == GemmUniversalMode::kGemm && grid_tiled_shape.k() > 1) {
// Serial split-K only requires a temporary workspace if the number of partitions along the
// GEMM K dimension is greater than one.
workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n());
}
CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes);
workspace_bytes += GemmKernel::get_extra_workspace_size(args, grid_tiled_shape);
return workspace_bytes;
}
/// Computes the grid shape
static dim3 get_grid_shape(Arguments const& args) {
CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_grid_shape()");
ThreadblockSwizzle threadblock_swizzle;
cutlass::gemm::GemmCoord grid_tiled_shape;
int gemm_k_size = 0;
get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
dim3 result = threadblock_swizzle.get_grid_shape(grid_tiled_shape);
CUTLASS_TRACE_HOST(
" grid_tiled_shape: " << grid_tiled_shape << "\n"
<< " result = {" << result << "}");
return result;
}
/// Computes the maximum number of active blocks per multiprocessor
static int maximum_active_blocks(int smem_capacity = -1) {
CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::maximum_active_blocks()");
int max_active_blocks = -1;
int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes");
if (smem_size <= (48 << 10)) {
cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks, Kernel<GemmKernel>, GemmKernel::kThreadCount, smem_size);
if (result == cudaSuccess) {
CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);
return max_active_blocks;
}
} else {
// Query assuming zero shared memory then compute occupancy limit based on SMEM
cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks, Kernel<GemmKernel>, GemmKernel::kThreadCount, 0);
if (result != cudaSuccess) {
CUTLASS_TRACE_HOST(
" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " << cudaGetErrorString(result));
return -1;
}
if (smem_capacity < 0) {
int device_idx = 0;
result = cudaGetDevice(&device_idx);
if (result != cudaSuccess) {
return -1;
}
cudaDeviceProp properties;
result = cudaGetDeviceProperties(&properties, device_idx);
if (result != cudaSuccess) {
return -1;
}
smem_capacity = static_cast<int>(properties.sharedMemPerMultiprocessor);
}
int occupancy = std::min(max_active_blocks, smem_capacity / smem_size);
CUTLASS_TRACE_HOST(" occupancy: " << occupancy);
return occupancy;
}
CUTLASS_TRACE_HOST(" returning internal error");
return -1;
}
/// Initializes GEMM state from arguments.
Status initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
CUTLASS_TRACE_HOST(
"GemmUniversalBaseCompat::initialize() - workspace " << workspace
<< ", stream: " << (stream ? "non-null" : "null"));
size_t workspace_bytes = get_workspace_size(args);
CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes);
if (workspace_bytes) {
if (!workspace) {
CUTLASS_TRACE_HOST(" error: device workspace must not be null");
return Status::kErrorWorkspaceNull;
}
if (args.mode == GemmUniversalMode::kGemm) {
CUTLASS_TRACE_HOST(" clearing device workspace");
cudaError_t result = cudaMemsetAsync(workspace, 0, workspace_bytes, stream);
if (result != cudaSuccess) {
CUTLASS_TRACE_HOST(" cudaMemsetAsync() returned error " << cudaGetErrorString(result));
return Status::kErrorInternal;
}
}
}
// Get CUDA grid shape
cutlass::gemm::GemmCoord grid_tiled_shape;
int gemm_k_size = 0;
get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
// Initialize the Params structure
params_ = typename GemmKernel::Params(args, grid_tiled_shape, gemm_k_size, static_cast<int*>(workspace));
// Specify shared memory capacity for kernel.
int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
if (smem_size >= (48 << 10)) {
cudaError_t result =
cudaFuncSetAttribute(Kernel<GemmKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
if (result != cudaSuccess) {
return Status::kErrorInternal;
}
}
return Status::kSuccess;
}
/// Lightweight update given a subset of arguments
Status update(Arguments const& args, void* workspace = nullptr) {
CUTLASS_TRACE_HOST("GemmUniversalBaseCompat()::update() - workspace: " << workspace);
size_t workspace_bytes = get_workspace_size(args);
if (workspace_bytes && !workspace) {
return Status::kErrorWorkspaceNull;
}
params_.update(args, workspace);
return Status::kSuccess;
}
/// Runs the kernel using initialized state.
Status run(cudaStream_t stream = nullptr) {
CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::run()");
//
// Configure grid and block dimensions
//
ThreadblockSwizzle threadblock_swizzle;
dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
dim3 block(GemmKernel::kThreadCount, 1, 1);
int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
//
// Launch kernel
//
CUTLASS_TRACE_HOST(" grid: (" << grid << "), block: (" << block << "), SMEM: " << smem_size << " bytes");
// Launch
cutlass::Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params_);
//
// Query for errors
//
cudaError_t result = cudaGetLastError();
if (result != cudaSuccess) {
CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result));
return Status::kErrorInternal;
}
return Status::kSuccess;
}
/// Runs the kernel using initialized state.
Status operator()(cudaStream_t stream = nullptr) {
return run(stream);
}
/// Runs the kernel using initialized state.
Status operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
Status status = initialize(args, workspace, stream);
if (status == Status::kSuccess) {
status = run(stream);
}
return status;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace device
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -0,0 +1,492 @@
/* Copyright 2025 SGLang Team. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Adapted from
// https://github.com/NVIDIA/TensorRT-LLM/blob/be1788106245496872d18e702978e59b6bfd50e0/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor.h
#pragma once
#include <cutlass/complex.h>
#include <cutlass/cutlass.h>
#include <cutlass/fast_math.h>
#include <cutlass/matrix_coord.h>
#include <cutlass/trace.h>
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace kernel {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
typename Epilogue_, ///! Epilogue
typename ThreadblockSwizzle_ ///! Threadblock swizzling function
>
struct GemmWithEpilogueVisitor {
public:
using Mma = Mma_;
using Epilogue = Epilogue_;
using EpilogueVisitor = typename Epilogue::Visitor;
using ThreadblockSwizzle = ThreadblockSwizzle_;
using ElementA = typename Mma::IteratorA::Element;
using LayoutA = typename Mma::IteratorA::Layout;
using TensorRefA = TensorRef<ElementA, LayoutA>;
using ElementB = typename Mma::IteratorB::Element;
using LayoutB = typename Mma::IteratorB::Layout;
using TensorRefB = TensorRef<ElementB, LayoutB>;
using ElementCompute = typename EpilogueVisitor::ElementCompute;
using LayoutAlphaCol = cutlass::layout::RowMajor;
using LayoutAlphaRow = cutlass::layout::ColumnMajor;
using TensorRefAlphaCol = TensorRef<ElementCompute, LayoutAlphaCol>;
using TensorRefAlphaRow = TensorRef<ElementCompute, LayoutAlphaRow>;
using ElementC = typename EpilogueVisitor::ElementOutput;
using LayoutC = typename Epilogue::Layout;
using TensorRefC = TensorRef<ElementC, LayoutC>;
static ComplexTransform const kTransformA = Mma::kTransformA;
static ComplexTransform const kTransformB = Mma::kTransformB;
using Operator = typename Mma::Operator;
using OperatorClass = typename Mma::Operator::OperatorClass;
using ThreadblockShape = typename Mma::Shape;
using WarpShape = typename Mma::Operator::Shape;
using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
using ArchTag = typename Mma::ArchTag;
using EpilogueOutputOp =
typename Epilogue::Visitor::ElementwiseFunctor; // Define type so GemmUniversalBase doesn't complain
static int const kStages = Mma::kStages;
static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
static int const kAlignmentC = EpilogueVisitor::kElementsPerAccess;
/// Warp count (concept: GemmShape)
using WarpCount = typename Mma::WarpCount;
static int const kThreadCount = 32 * WarpCount::kCount;
/// Split-K preserves splits that are 128b aligned
static int const kSplitKAlignment = const_max(128 / sizeof_bits<ElementA>::value, 128 / sizeof_bits<ElementB>::value);
//
// Structures
//
/// Argument structure
struct Arguments {
//
// Data members
//
GemmUniversalMode mode;
GemmCoord problem_size;
int batch_count;
TensorRefA ref_A;
TensorRefB ref_B;
TensorRefAlphaCol ref_alpha_col;
TensorRefAlphaRow ref_alpha_row;
TensorRefC ref_C;
TensorRefC ref_D;
int64_t batch_stride_A;
int64_t batch_stride_B;
int64_t batch_stride_D;
typename EpilogueVisitor::Arguments epilogue_visitor;
//
// Methods
//
Arguments() : mode(GemmUniversalMode::kGemm), batch_count(1) {}
/// constructs an arguments structure
Arguments(
GemmCoord problem_size_,
TensorRefA ref_A_,
TensorRefB ref_B_,
TensorRefAlphaCol ref_alpha_col_,
TensorRefAlphaRow ref_alpha_row_,
TensorRefC ref_C_,
TensorRefC ref_D_,
typename EpilogueVisitor::Arguments epilogue_visitor_)
: mode(GemmUniversalMode::kGemm),
problem_size(problem_size_),
batch_count(1),
ref_A(ref_A_),
ref_B(ref_B_),
ref_alpha_col(ref_alpha_col_),
ref_alpha_row(ref_alpha_row_),
ref_C(ref_C_),
ref_D(ref_D_),
batch_stride_A(0),
batch_stride_B(0),
batch_stride_D(0),
epilogue_visitor(epilogue_visitor_) {}
};
//
// Structure for precomputing values in host memory and passing to kernels
//
/// Parameters structure
struct Params {
cutlass::gemm::GemmCoord problem_size;
cutlass::gemm::GemmCoord grid_tiled_shape;
int swizzle_log_tile;
typename Mma::IteratorA::Params params_A;
typename Mma::IteratorB::Params params_B;
typename EpilogueVisitor::ScaleTileIterator::Params params_alpha_col;
typename EpilogueVisitor::ScaleTileIterator::Params params_alpha_row;
typename EpilogueVisitor::OutputTileIterator::Params params_C;
typename EpilogueVisitor::OutputTileIterator::Params params_D;
GemmUniversalMode mode;
int batch_count;
int gemm_k_size;
void* ptr_A;
void* ptr_B;
typename EpilogueVisitor::ScaleTileIterator::Element* ptr_alpha_col;
typename EpilogueVisitor::ScaleTileIterator::Element* ptr_alpha_row;
ElementC* ptr_C;
ElementC* ptr_D;
int64_t batch_stride_A;
int64_t batch_stride_B;
typename EpilogueVisitor::Params epilogue_visitor;
//
// Methods
//
CUTLASS_HOST_DEVICE
Params()
: swizzle_log_tile(0),
params_A(0),
params_B(0),
params_alpha_col(0),
params_C(0),
params_D(0),
batch_count(0),
gemm_k_size(0),
mode(cutlass::gemm::GemmUniversalMode::kGemm),
ptr_A(nullptr),
ptr_B(nullptr),
ptr_alpha_col(nullptr),
ptr_alpha_row(nullptr),
ptr_C(nullptr),
ptr_D(nullptr),
batch_stride_A(0),
batch_stride_B(0) {}
Params(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape_, int gemm_k_size_, int* workspace_)
: problem_size(args.problem_size),
swizzle_log_tile(0),
params_A(args.ref_A.layout()),
params_B(args.ref_B.layout()),
params_alpha_col(args.ref_alpha_col.layout()),
params_alpha_row(args.ref_alpha_col.layout()),
params_C(args.ref_C.layout()),
params_D(args.ref_D.layout()),
mode(args.mode),
batch_count(args.batch_count),
gemm_k_size(args.problem_size.k()),
ptr_A(args.ref_A.data()),
ptr_B(args.ref_B.data()),
ptr_alpha_col(args.ref_alpha_col.data()),
ptr_alpha_row(args.ref_alpha_row.data()),
ptr_C(args.ref_C.data()),
ptr_D(args.ref_D.data()),
batch_stride_A(args.batch_stride_A),
batch_stride_B(args.batch_stride_B),
epilogue_visitor(args.epilogue_visitor) {
ThreadblockSwizzle threadblock_swizzle;
grid_tiled_shape = threadblock_swizzle.get_tiled_shape(
args.problem_size, {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.batch_count);
if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) {
int const kAlignK =
const_max(const_max(128 / sizeof_bits<ElementA>::value, 128 / sizeof_bits<ElementB>::value), 1);
gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK);
if (gemm_k_size) {
grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size);
}
}
swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape);
}
};
/// Shared memory storage structure
union SharedStorage {
typename Mma::SharedStorage main_loop;
struct {
typename Epilogue::SharedStorage epilogue;
typename EpilogueVisitor::SharedStorage visitor;
} epilogue;
};
public:
//
// Methods
//
CUTLASS_DEVICE
GemmWithEpilogueVisitor() {}
/// Determines whether kernel satisfies alignment
static Status can_implement(cutlass::gemm::GemmCoord const& problem_size) {
CUTLASS_TRACE_HOST("GemmWithEpilogueVisitor::can_implement()");
static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
static int const kAlignmentC = EpilogueVisitor::OutputTileIterator::kElementsPerAccess;
bool isAMisaligned = false;
bool isBMisaligned = false;
bool isCMisaligned = false;
if (platform::is_same<LayoutA, layout::RowMajor>::value) {
isAMisaligned = problem_size.k() % kAlignmentA;
} else if (platform::is_same<LayoutA, layout::ColumnMajor>::value) {
isAMisaligned = problem_size.m() % kAlignmentA;
} else if (
platform::is_same<LayoutA, layout::ColumnMajorInterleaved<32>>::value ||
platform::is_same<LayoutA, layout::ColumnMajorInterleaved<64>>::value) {
isAMisaligned = problem_size.k() % kAlignmentA;
}
if (platform::is_same<LayoutB, layout::RowMajor>::value) {
isBMisaligned = problem_size.n() % kAlignmentB;
} else if (platform::is_same<LayoutB, layout::ColumnMajor>::value) {
isBMisaligned = problem_size.k() % kAlignmentB;
} else if (
platform::is_same<LayoutB, layout::RowMajorInterleaved<32>>::value ||
platform::is_same<LayoutB, layout::RowMajorInterleaved<64>>::value) {
isBMisaligned = problem_size.k() % kAlignmentB;
}
if (platform::is_same<LayoutC, layout::RowMajor>::value) {
isCMisaligned = problem_size.n() % kAlignmentC;
} else if (platform::is_same<LayoutC, layout::ColumnMajor>::value) {
isCMisaligned = problem_size.m() % kAlignmentC;
} else if (
platform::is_same<LayoutC, layout::ColumnMajorInterleaved<32>>::value ||
platform::is_same<LayoutC, layout::ColumnMajorInterleaved<64>>::value) {
isCMisaligned = problem_size.n() % kAlignmentC;
}
if (isAMisaligned) {
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand");
return Status::kErrorMisalignedOperand;
}
if (isBMisaligned) {
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand");
return Status::kErrorMisalignedOperand;
}
if (isCMisaligned) {
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand");
return Status::kErrorMisalignedOperand;
}
CUTLASS_TRACE_HOST(" returning kSuccess");
return Status::kSuccess;
}
static Status can_implement(Arguments const& args) {
return can_implement(args.problem_size);
}
static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape) {
return 0;
}
#define SPLIT_K_ENABLED 1
/// Executes one GEMM
CUTLASS_DEVICE
void run_kernel_(Params const& params, SharedStorage& shared_storage) {
// Compute threadblock location
ThreadblockSwizzle threadblock_swizzle;
cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
// Early exit if CTA is out of range
if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {
return;
}
int offset_k = 0;
int problem_size_k = params.problem_size.k();
ElementA* ptr_A = static_cast<ElementA*>(params.ptr_A);
ElementB* ptr_B = static_cast<ElementB*>(params.ptr_B);
#if SPLIT_K_ENABLED
//
// Fetch pointers based on mode.
//
if (params.mode == GemmUniversalMode::kGemm || params.mode == GemmUniversalMode::kGemmSplitKParallel) {
if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) {
problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size;
}
offset_k = threadblock_tile_offset.k() * params.gemm_k_size;
} else if (params.mode == GemmUniversalMode::kBatched) {
ptr_A += threadblock_tile_offset.k() * params.batch_stride_A;
ptr_B += threadblock_tile_offset.k() * params.batch_stride_B;
} else if (params.mode == GemmUniversalMode::kArray) {
ptr_A = static_cast<ElementA* const*>(params.ptr_A)[threadblock_tile_offset.k()];
ptr_B = static_cast<ElementB* const*>(params.ptr_B)[threadblock_tile_offset.k()];
}
#endif
// Compute initial location in logical coordinates
cutlass::MatrixCoord tb_offset_A{
threadblock_tile_offset.m() * Mma::Shape::kM,
offset_k,
};
cutlass::MatrixCoord tb_offset_B{offset_k, threadblock_tile_offset.n() * Mma::Shape::kN};
// Compute position within threadblock
int thread_idx = threadIdx.x;
// Construct iterators to A and B operands
typename Mma::IteratorA iterator_A(
params.params_A, ptr_A, {params.problem_size.m(), problem_size_k}, thread_idx, tb_offset_A);
typename Mma::IteratorB iterator_B(
params.params_B, ptr_B, {problem_size_k, params.problem_size.n()}, thread_idx, tb_offset_B);
// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
int lane_idx = threadIdx.x % 32;
//
// Main loop
//
// Construct thread-scoped matrix multiply
Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
typename Mma::FragmentC accumulators;
accumulators.clear();
// Compute threadblock-scoped matrix multiply-add
int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK;
// Compute threadblock-scoped matrix multiply-add
mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators);
//
// Masked tile iterators constructed from members
//
threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
// assume identity swizzle
MatrixCoord threadblock_offset(
threadblock_tile_offset.m() * Mma::Shape::kM, threadblock_tile_offset.n() * Mma::Shape::kN);
int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
//
// Construct the epilogue visitor
//
bool with_bias = true;
if (params.ptr_C == nullptr) {
with_bias = false;
}
EpilogueVisitor epilogue_visitor(
params.epilogue_visitor,
shared_storage.epilogue.visitor,
params.problem_size.mn(),
thread_idx,
warp_idx,
lane_idx,
params.params_alpha_col,
params.params_C,
params.params_D,
with_bias,
true,
true,
params.ptr_alpha_row,
params.ptr_alpha_col,
params.ptr_C,
params.ptr_D,
threadblock_offset,
blockIdx.y * params.problem_size.m());
if (params.mode == GemmUniversalMode::kGemm) {
// Indicate which position in a serial reduction the output operator is currently updating
epilogue_visitor.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
} else if (params.mode == GemmUniversalMode::kBatched || params.mode == GemmUniversalMode::kArray) {
epilogue_visitor.set_batch_index(threadblock_tile_offset.k());
}
// Construct the epilogue
Epilogue epilogue(shared_storage.epilogue.epilogue, thread_idx, warp_idx, lane_idx);
// Execute the epilogue operator to update the destination tensor.
epilogue(epilogue_visitor, accumulators);
}
template <typename CompilationArch>
CUTLASS_DEVICE void run_kernel(Params const& params, SharedStorage& shared_storage) {
if constexpr (platform::is_same<ArchTag, CompilationArch>::value) {
run_kernel_(params, shared_storage);
} else {
CUTLASS_NOT_IMPLEMENTED();
}
}
/// Executes one GEMM
CUTLASS_DEVICE
void operator()(Params const& params, SharedStorage& shared_storage) {
run_kernel<ArchTag>(params, shared_storage);
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////