adapt to sglang v0.5.2rc1 on dcu
This commit is contained in:
21
sgl-kernel/csrc/cutlass_extensions/common.hpp
Normal file
21
sgl-kernel/csrc/cutlass_extensions/common.hpp
Normal file
@@ -0,0 +1,21 @@
|
||||
#pragma once
|
||||
|
||||
#include "cuda_runtime.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
/**
|
||||
* A wrapper for a kernel that is used to guard against compilation on
|
||||
* architectures that will never use the kernel. The purpose of this is to
|
||||
* reduce the size of the compiled binary.
|
||||
* __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
|
||||
* into code that will be executed on the device where it is defined.
|
||||
*/
|
||||
template <typename Kernel>
|
||||
struct enable_sm90_or_later : Kernel {
|
||||
template <typename... Args>
|
||||
CUTLASS_DEVICE void operator()(Args&&... args) {
|
||||
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900
|
||||
Kernel::operator()(std::forward<Args>(args)...);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
@@ -0,0 +1,482 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "cute/arch/copy_sm90.hpp"
|
||||
#include "cute/numeric/arithmetic_tuple.hpp"
|
||||
#include "cute/util/type_traits.hpp"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::gemm::collective::detail {
|
||||
|
||||
template <class Collective>
|
||||
struct MixedGroupedGemmInputUtils {
|
||||
private:
|
||||
using KernelSchedule = typename Collective::KernelSchedule;
|
||||
using ConversionMode = typename Collective::ConversionMode;
|
||||
using SmemLayoutA = typename Collective::SmemLayoutA;
|
||||
using SmemLayoutB = typename Collective::SmemLayoutB;
|
||||
using SmemLayoutScale = typename Collective::SmemLayoutScale;
|
||||
using SwappedElementA = typename Collective::SwappedElementA;
|
||||
using SwappedElementB = typename Collective::SwappedElementB;
|
||||
using RealSwappedElementA = typename Collective::RealSwappedElementA;
|
||||
using RealSwappedElementB = typename Collective::RealSwappedElementB;
|
||||
using ElementScale = typename Collective::ElementScale;
|
||||
using ElementZero = typename Collective::ElementZero;
|
||||
using SmemCopyAtomScale = typename Collective::SmemCopyAtomScale;
|
||||
static constexpr auto KernelConversionMode = Collective::KernelConversionMode;
|
||||
static constexpr auto ModeHasScales = Collective::ModeHasScales;
|
||||
static constexpr auto UseScaleLookupTable = Collective::UseScaleLookupTable;
|
||||
|
||||
public:
|
||||
static constexpr auto elements_per_smem_scale() {
|
||||
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
|
||||
return 0;
|
||||
} else if constexpr (ModeHasScales) {
|
||||
return cute::cosize_v<SmemLayoutScale>;
|
||||
} else {
|
||||
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Type not handled in scale smem allocation.");
|
||||
}
|
||||
}
|
||||
|
||||
static constexpr auto elements_per_smem_zero() {
|
||||
if constexpr (
|
||||
KernelConversionMode == ConversionMode::DirectConvert ||
|
||||
KernelConversionMode == ConversionMode::ConvertAndScale) {
|
||||
return 0;
|
||||
} else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
|
||||
return cute::cosize_v<SmemLayoutScale>;
|
||||
} else {
|
||||
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Type not handled in scale smem allocation.");
|
||||
}
|
||||
}
|
||||
|
||||
// These methods use some of the public members of the class. For that reason, we define them after the public section.
|
||||
static constexpr uint32_t compute_tma_transaction_bytes_mk() {
|
||||
return cutlass::bits_to_bytes(
|
||||
size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast<uint32_t>(cute::sizeof_bits_v<SwappedElementA>));
|
||||
}
|
||||
|
||||
static constexpr uint32_t compute_tma_transaction_bytes_nk() {
|
||||
return cutlass::bits_to_bytes(
|
||||
size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast<uint32_t>(cute::sizeof_bits_v<SwappedElementB>));
|
||||
}
|
||||
|
||||
static constexpr uint32_t compute_tma_transaction_bytes_extra() {
|
||||
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
|
||||
return 0;
|
||||
} else if constexpr (ModeHasScales) {
|
||||
constexpr uint32_t scale_tx_bytes = cutlass::bits_to_bytes(
|
||||
size<0>(SmemLayoutScale{}) * size<1>(SmemLayoutScale{}) *
|
||||
static_cast<uint32_t>(cute::sizeof_bits_v<ElementScale>));
|
||||
static_assert(scale_tx_bytes % 128 == 0, "Each scale stage must be 128B aligned."); // required by TMA
|
||||
if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
|
||||
return scale_tx_bytes;
|
||||
} else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
|
||||
// Scale and zero share smem layout
|
||||
constexpr uint32_t zero_tx_bytes = cutlass::bits_to_bytes(
|
||||
size<0>(SmemLayoutScale{}) * size<1>(SmemLayoutScale{}) *
|
||||
static_cast<uint32_t>(cute::sizeof_bits_v<ElementZero>));
|
||||
static_assert(zero_tx_bytes % 128 == 0, "Each zero stage must be 128B aligned."); // required by TMA
|
||||
return scale_tx_bytes + zero_tx_bytes;
|
||||
} else {
|
||||
static_assert(
|
||||
cutlass::detail::dependent_false<KernelSchedule>, "Type not handled in tma transaction bytes computation.");
|
||||
}
|
||||
} else {
|
||||
static_assert(
|
||||
cutlass::detail::dependent_false<KernelSchedule>, "Type not handled in tma transaction bytes computation.");
|
||||
}
|
||||
}
|
||||
|
||||
/// Utilities to copy A and extra inputs from smem to RF
|
||||
template <class SmemTiledCopyA, class TensorASmemView, class TensorACopyView, class... Ts, class... Us>
|
||||
CUTLASS_DEVICE static void copy_tensors_MK(
|
||||
SmemTiledCopyA const& smem_tiled_copy_A,
|
||||
TensorASmemView const& tCsA,
|
||||
TensorACopyView& tCrA_copy_view,
|
||||
cute::tuple<Ts...> const& partitioned_mma_extra_info,
|
||||
cute::tuple<Us...> const& tiled_copy_and_views,
|
||||
int k_block,
|
||||
int read_stage) {
|
||||
copy(smem_tiled_copy_A, tCsA(_, _, k_block, read_stage), tCrA_copy_view(_, _, k_block));
|
||||
|
||||
if (k_block == 0) {
|
||||
// We are starting a new k-tile so copy the scale
|
||||
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
|
||||
// nothing to do
|
||||
} else if constexpr (ModeHasScales) {
|
||||
auto smem_tiled_copy_S = cute::get<0>(tiled_copy_and_views);
|
||||
auto tCrS_copy_view = cute::get<1>(tiled_copy_and_views);
|
||||
auto tCsS = cute::get<0>(partitioned_mma_extra_info);
|
||||
copy(smem_tiled_copy_S, tCsS(_, _, k_block, read_stage), tCrS_copy_view(_, _, k_block));
|
||||
if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
|
||||
// Nothing extra to do
|
||||
} else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
|
||||
auto tCsZ = cute::get<2>(partitioned_mma_extra_info);
|
||||
auto tCrZ_copy_view = cute::get<2>(tiled_copy_and_views);
|
||||
copy(smem_tiled_copy_S, tCsZ(_, _, k_block, read_stage), tCrZ_copy_view(_, _, k_block));
|
||||
} else {
|
||||
static_assert(
|
||||
cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in A -> RF path.");
|
||||
}
|
||||
} else {
|
||||
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in A -> RF path.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// The core converter uses a lookup table to convert i4 -> 8-bit values.
|
||||
template <
|
||||
class EngineIn,
|
||||
class LayoutIn,
|
||||
class EngineOut,
|
||||
class LayoutOut,
|
||||
class EngineScale,
|
||||
class LayoutScale>
|
||||
CUTLASS_DEVICE static void lookup_table_convert( // Accept mutable temporaries
|
||||
Tensor<EngineIn, LayoutIn> const& src,
|
||||
Tensor<EngineOut, LayoutOut>&& dst,
|
||||
Tensor<EngineScale, LayoutScale> const& scales_neg,
|
||||
Tensor<EngineScale, LayoutScale> const& scales_pos) {
|
||||
lookup_table_convert(src, dst, scales_neg, scales_pos);
|
||||
}
|
||||
template <class EngineIn, class LayoutIn, class EngineOut, class LayoutOut, class EngineScale, class LayoutScale>
|
||||
CUTLASS_DEVICE static void lookup_table_convert(
|
||||
Tensor<EngineIn, LayoutIn> const& src,
|
||||
Tensor<EngineOut, LayoutOut>& dst,
|
||||
Tensor<EngineScale, LayoutScale> const& scales_neg,
|
||||
Tensor<EngineScale, LayoutScale> const& scales_pos) {
|
||||
constexpr int N = cute::cosize(LayoutIn{});
|
||||
static_assert(N == 4 || N == 8);
|
||||
static_assert(cosize(LayoutScale{}) <= N / 4, "at least 4 consecutive weights must share the same scale.");
|
||||
using SrcArray = cutlass::Array<cutlass::int4b_t, 8>;
|
||||
using DstArray = cutlass::Array<RealSwappedElementB, 8>;
|
||||
using RegArray = cutlass::AlignedArray<uint32_t, N / 4, sizeof(DstArray)>;
|
||||
|
||||
// View the input as reg
|
||||
auto&& src_reg = cute::recast<uint32_t>(src)(0);
|
||||
auto&& r = cute::recast<RegArray>(dst)(0);
|
||||
|
||||
// Determines if to get from the signed or unsigned candidates
|
||||
static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
|
||||
uint32_t sign; // ((reg & 0x88888888) | 0x64206420) >> 1
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" lop3.b32 %0, %1, %2, %3, %4;\n"
|
||||
"}\n"
|
||||
: "=r"(sign)
|
||||
: "r"(src_reg), "n"(0x88888888), "n"(0x64206420), "n"(immLut));
|
||||
sign = sign >> 1;
|
||||
|
||||
// Ignore sign bit when indexing into LUT
|
||||
uint32_t lut_idx = src_reg & 0x77777777;
|
||||
Tensor scales_neg_ = cute::filter(scales_neg);
|
||||
Tensor scales_pos_ = cute::filter(scales_pos);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < N / 4; ++i, lut_idx >>= 16, sign >>= 16) {
|
||||
auto&& scale_neg_ = reinterpret_cast<cutlass::Array<uint32_t, 2> const&>(scales_neg_(i));
|
||||
auto&& scale_pos_ = reinterpret_cast<cutlass::Array<uint32_t, 2> const&>(scales_pos_(i));
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" .reg .b32 pos, neg ;\n"
|
||||
" prmt .b32 neg, %3, %4, %1 ;\n"
|
||||
" prmt .b32 pos, %5, %6, %1 ;\n"
|
||||
" prmt .b32 %0, pos, neg, %2 ;\n"
|
||||
"}\n"
|
||||
: "=r"(r[i])
|
||||
: "r"(lut_idx), "r"(sign), "r"(scale_neg_[0]), "r"(scale_neg_[1]), "r"(scale_pos_[0]), "r"(scale_pos_[1]));
|
||||
}
|
||||
}
|
||||
|
||||
/// Utilities to dequantize A.
|
||||
template <class Layout>
|
||||
CUTLASS_DEVICE static void static_check_scale(Layout const& tensor) {
|
||||
static_assert(
|
||||
shape<0>(Layout{}) >= 4 && stride<0>(Layout{}) == 0,
|
||||
"At least 4 adjacent weights in a thread must share the same scale.");
|
||||
}
|
||||
template <class Engine, class Layout>
|
||||
CUTLASS_DEVICE static void static_check_scale(Tensor<Engine, Layout> const& tensor) {
|
||||
static_check_scale(flatten(Layout{}));
|
||||
}
|
||||
|
||||
template <class EngineIn, class EngineOut, class LayoutIn, class LayoutOut, class... Ts>
|
||||
CUTLASS_DEVICE static void dequantize_A_kblock(
|
||||
Tensor<EngineIn, LayoutIn> const& tCrA_load,
|
||||
Tensor<EngineOut, LayoutOut>& tCrA_mma,
|
||||
cute::tuple<Ts...>& partitioned_extra_info,
|
||||
int const k_block) {
|
||||
static_assert(is_rmem<EngineIn>::value, "Input tensor for A conversion must come from registers");
|
||||
static_assert(is_rmem<EngineOut>::value, "Output tensor for A conversion must come from registers");
|
||||
static_assert(cosize_v<LayoutIn> == cosize_v<LayoutOut>);
|
||||
static_assert(size_v<LayoutIn> == cosize_v<LayoutIn>);
|
||||
static_assert(size_v<LayoutOut> == cosize_v<LayoutOut>);
|
||||
using SrcType = typename EngineIn::value_type;
|
||||
using DstType = typename EngineOut::value_type;
|
||||
|
||||
Tensor src = tCrA_load(_, _, k_block);
|
||||
Tensor dst = tCrA_mma(_, _, k_block);
|
||||
|
||||
CUTE_STATIC_ASSERT_V(
|
||||
size(src(_, 0)) == cosize(src(_, 0).layout()), "The first mode of tensor src must be contiguous in memory");
|
||||
// try to make the size of the first mode equal to 32bit
|
||||
int constexpr NumValPerSrcReg = cute::min(decltype(size(src(_, 0)))::value, ceil_div(32, sizeof_bits_v<SrcType>));
|
||||
Tensor src_vm = cute::group_modes<1, -1>(cute::zipped_divide(src, Int<NumValPerSrcReg>{}));
|
||||
Tensor dst_vm = cute::group_modes<1, -1>(cute::zipped_divide(dst, Int<NumValPerSrcReg>{}));
|
||||
|
||||
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size<1>(dst_vm); ++i) {
|
||||
LayoutAwareConvert(src_vm(_, i), dst_vm(_, i));
|
||||
}
|
||||
} else if constexpr (UseScaleLookupTable) {
|
||||
constexpr int num_elements = decltype(size(src))::value;
|
||||
static_assert(
|
||||
is_same_v<RealSwappedElementA, cutlass::int4b_t>,
|
||||
"Lookup table only supports int4 being the quant type now.");
|
||||
static_assert(sizeof_bits_v<ElementScale> == 64, "Lookup table only supports 8 8bit scale values now.");
|
||||
static_assert(
|
||||
num_elements % 4 == 0 && num_elements >= 4, "Lookup table requires a vector size of 4x when converting.");
|
||||
|
||||
Tensor tCrS_neg = cute::get<1>(partitioned_extra_info);
|
||||
auto&& tCrS_pos = cute::get<2>(partitioned_extra_info); // modification to its value is needed
|
||||
Tensor scales_neg = tCrS_neg(_, _, k_block);
|
||||
Tensor scales_pos = tCrS_pos(_, _, k_block);
|
||||
CUTE_STATIC_ASSERT_V(cute::size(src) == cute::size(scales_neg));
|
||||
|
||||
static_check_scale(scales_neg);
|
||||
static_check_scale(scales_pos);
|
||||
Tensor scales_neg_vm = cute::group_modes<1, -1>(cute::zipped_divide(scales_neg, Int<NumValPerSrcReg>{}));
|
||||
Tensor scales_pos_vm = cute::group_modes<1, -1>(cute::zipped_divide(scales_pos, Int<NumValPerSrcReg>{}));
|
||||
|
||||
if (k_block == 0) {
|
||||
Tensor scales_neg_vm_ = filter(scales_neg_vm);
|
||||
Tensor scales_pos_vm_ = filter(scales_pos_vm);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(scales_neg_vm_.layout()); ++i) {
|
||||
auto&& scale_neg_ = reinterpret_cast<cutlass::Array<uint32_t, 2> const&>(scales_neg_vm_(i));
|
||||
auto&& scale_pos_ = reinterpret_cast<cutlass::Array<uint32_t, 2>&>(scales_pos_vm_(i));
|
||||
constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa;
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" lop3 .b32 %0, %2, %4, %5, %6;\n"
|
||||
" xor .b32 %1, %3, %5; \n"
|
||||
"}\n"
|
||||
: "=r"(scale_pos_[0]), "=r"(scale_pos_[1])
|
||||
: "r"(scale_neg_[0]), "r"(scale_neg_[1]), "n"(0xFFFFFF00), "n"(0x80808080), "n"(immLut));
|
||||
}
|
||||
}
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size<1>(dst_vm); ++i) {
|
||||
lookup_table_convert(src_vm(_, i), dst_vm(_, i), scales_neg_vm(_, i), scales_pos_vm(_, i));
|
||||
}
|
||||
} else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
|
||||
Tensor scales = cute::get<1>(partitioned_extra_info)(_, _, k_block);
|
||||
CUTE_STATIC_ASSERT_V(size(src) == size(scales));
|
||||
Tensor scales_vm = cute::group_modes<1, -1>(cute::zipped_divide(scales, Int<NumValPerSrcReg>{}));
|
||||
|
||||
if constexpr (is_same_v<DstType, ElementScale>) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size<1>(dst_vm); ++i) {
|
||||
LayoutAwareConvert(src_vm(_, i), dst_vm(_, i));
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < size<0>(dst_vm); ++j) {
|
||||
dst_vm(j, i) *= scales_vm(j, i);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
auto stage = make_tensor_like<ElementScale>(src_vm(_, 0));
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size<1>(dst_vm); ++i) {
|
||||
LayoutAwareConvert(src_vm(_, i), stage);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < size<0>(dst_vm); ++j) {
|
||||
stage(j) *= scales_vm(j, i);
|
||||
}
|
||||
LayoutAwareConvert(stage, dst_vm(_, i));
|
||||
}
|
||||
}
|
||||
} else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
|
||||
static_assert(is_same_v<ElementScale, ElementZero>, "ElementScale and ElementZero must be the same.");
|
||||
Tensor scales = cute::get<1>(partitioned_extra_info)(_, _, k_block);
|
||||
Tensor zeros = cute::get<3>(partitioned_extra_info)(_, _, k_block);
|
||||
CUTE_STATIC_ASSERT_V(size(src) == size(scales));
|
||||
CUTE_STATIC_ASSERT_V(size(src) == size(zeros));
|
||||
Tensor scales_vm = cute::group_modes<1, -1>(cute::zipped_divide(scales, Int<NumValPerSrcReg>{}));
|
||||
Tensor zeros_vm = cute::group_modes<1, -1>(cute::zipped_divide(zeros, Int<NumValPerSrcReg>{}));
|
||||
|
||||
if constexpr (is_same_v<DstType, ElementScale>) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size<1>(dst_vm); ++i) {
|
||||
LayoutAwareConvert(src_vm(_, i), dst_vm(_, i));
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < size<0>(dst_vm); ++j) {
|
||||
dst_vm(j, i) = dst_vm(j, i) * scales_vm(j, i) + zeros_vm(j, i);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
auto stage = make_tensor_like<ElementScale>(src_vm(_, 0));
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size<1>(dst_vm); ++i) {
|
||||
LayoutAwareConvert(src_vm(_, i), stage);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < size<0>(dst_vm); ++j) {
|
||||
stage(j) = stage(j) * scales_vm(j, i) + zeros_vm(j, i);
|
||||
}
|
||||
LayoutAwareConvert(stage, dst_vm(_, i));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "No A data is loaded.");
|
||||
}
|
||||
}
|
||||
|
||||
template <class EngineIn, class EngineOut, class LayoutIn, class LayoutOut, class... Ts>
|
||||
CUTLASS_DEVICE static void convert_A_kblock(
|
||||
Tensor<EngineIn, LayoutIn> const& tCrA_load, Tensor<EngineOut, LayoutOut>& tCrA_mma, int const k_block) {
|
||||
static_assert(is_rmem<EngineIn>::value, "Input tensor for A conversion must come from registers");
|
||||
static_assert(is_rmem<EngineOut>::value, "Output tensor for A conversion must come from registers");
|
||||
static_assert(cosize_v<LayoutIn> == cosize_v<LayoutOut>);
|
||||
static_assert(size_v<LayoutIn> == cosize_v<LayoutIn>);
|
||||
static_assert(size_v<LayoutOut> == cosize_v<LayoutOut>);
|
||||
using SrcType = typename EngineIn::value_type;
|
||||
|
||||
Tensor src = tCrA_load(_, _, k_block);
|
||||
Tensor dst = tCrA_mma(_, _, k_block);
|
||||
|
||||
CUTE_STATIC_ASSERT_V(
|
||||
size(src(_, 0)) == cosize(src(_, 0).layout()), "The first mode of tensor src must be contiguous in memory");
|
||||
// try to make the size of the first mode equal to 32bit
|
||||
int constexpr NumValPerSrcReg = cute::min(decltype(size(src(_, 0)))::value, ceil_div(32, sizeof_bits_v<SrcType>));
|
||||
Tensor src_vm = cute::group_modes<1, -1>(cute::zipped_divide(src, Int<NumValPerSrcReg>{}));
|
||||
Tensor dst_vm = cute::group_modes<1, -1>(cute::zipped_divide(dst, Int<NumValPerSrcReg>{}));
|
||||
|
||||
// KernelConversionMode == ConversionMode::DirectConvert
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size<1>(dst_vm); ++i) {
|
||||
LayoutAwareConvert(src_vm(_, i), dst_vm(_, i));
|
||||
}
|
||||
}
|
||||
|
||||
/// Utilities for any additional inputs inside of the TMA load
|
||||
template <class Params, class TensorStorage, class... Ts>
|
||||
CUTLASS_DEVICE static auto partition_extra_tma_inputs(
|
||||
Params const& mainloop_params,
|
||||
cute::tuple<Ts...> const& load_inputs,
|
||||
TensorStorage& shared_tensors,
|
||||
uint2 const& cluster_local_block_id,
|
||||
int const m_coord,
|
||||
int const l_coord) {
|
||||
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
|
||||
return cute::make_tuple();
|
||||
} else if constexpr (ModeHasScales) {
|
||||
Tensor sS =
|
||||
make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{}); // (BLK_M,BLK_K,PIPE)
|
||||
Tensor gS_mkl = get<2>(load_inputs);
|
||||
auto block_tma_s = mainloop_params.tma_load_scale.get_slice(cluster_local_block_id.y);
|
||||
Tensor gS = gS_mkl(_, _, m_coord, _, l_coord); // (BLK_M,BLK_K,k)
|
||||
|
||||
Tensor tSgS = block_tma_s.partition_S(gS); // (TMA,TMA_M,TMA_K,k)
|
||||
Tensor tSsS = block_tma_s.partition_D(sS); // (TMA,TMA_M,TMA_K,PIPE)
|
||||
if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
|
||||
return cute::make_tuple(tSgS, tSsS);
|
||||
} else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
|
||||
Tensor sZ =
|
||||
make_tensor(make_smem_ptr(shared_tensors.smem_zero.begin()), SmemLayoutScale{}); // (BLK_M,BLK_K,PIPE)
|
||||
Tensor gZ_mkl = get<3>(load_inputs);
|
||||
auto block_tma_z = mainloop_params.tma_load_zero.get_slice(cluster_local_block_id.y);
|
||||
Tensor gZ = gZ_mkl(_, _, m_coord, _, l_coord); // (BLK_M,BLK_K,k)
|
||||
|
||||
Tensor tZgZ = block_tma_z.partition_S(gZ); // (TMA,TMA_M,TMA_K,k)
|
||||
Tensor tZsZ = block_tma_z.partition_D(sZ); // (TMA,TMA_M,TMA_K,PIPE)
|
||||
return cute::make_tuple(tSgS, tSsS, tZgZ, tZsZ);
|
||||
} else {
|
||||
static_assert(
|
||||
cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled for input partitioning.");
|
||||
}
|
||||
} else {
|
||||
static_assert(
|
||||
cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled for input partitioning.");
|
||||
}
|
||||
}
|
||||
|
||||
/// Utilities for partitioning extra inputs for loading from smem in the mainloop.
|
||||
template <class ThreadMma, class TensorStorage>
|
||||
CUTLASS_DEVICE static auto
|
||||
partition_extra_mma_info(ThreadMma const& mma_thread_slice, TensorStorage& shared_tensors) {
|
||||
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
|
||||
// nothing to do
|
||||
return cute::make_tuple();
|
||||
} else if constexpr (UseScaleLookupTable) {
|
||||
Tensor sS =
|
||||
make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{}); // (BLK_M,BLK_SCALE_K,PIPE)
|
||||
Tensor tCsS = mma_thread_slice.partition_A(sS);
|
||||
Tensor tCrS = make_tensor<ElementScale>(mma_thread_slice.partition_fragment_A(sS(_, _, Int<0>{})).layout());
|
||||
|
||||
return cute::make_tuple(tCsS, tCrS);
|
||||
} else if constexpr (ModeHasScales) {
|
||||
Tensor sS =
|
||||
make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{}); // (BLK_M,BLK_SCALE_K,PIPE)
|
||||
Tensor tCsS = mma_thread_slice.partition_A(sS);
|
||||
Tensor tCrS = make_tensor<ElementScale>(mma_thread_slice.partition_fragment_A(sS(_, _, Int<0>{})).layout());
|
||||
|
||||
if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
|
||||
return cute::make_tuple(tCsS, tCrS);
|
||||
} else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
|
||||
Tensor sZ = make_tensor(
|
||||
make_smem_ptr(shared_tensors.smem_zero.begin()), SmemLayoutScale{}); // (BLK_M,BLK_SCALE_K,PIPE)
|
||||
Tensor tCsZ = mma_thread_slice.partition_A(sZ);
|
||||
Tensor tCrZ = make_tensor<ElementZero>(mma_thread_slice.partition_fragment_A(sZ(_, _, Int<0>{})).layout());
|
||||
return cute::make_tuple(tCsS, tCrS, tCsZ, tCrZ);
|
||||
} else {
|
||||
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in A -> RF path.");
|
||||
}
|
||||
} else {
|
||||
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in A -> RF path.");
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the tiled copy and copy views for the extra inputs.
|
||||
template <class TiledMma, class... Ts>
|
||||
CUTLASS_DEVICE static auto retile_extra_mma_info(
|
||||
TiledMma const& tiled_mma, cute::tuple<Ts...>& partitioned_extra_info, int const warp_group_thread_idx) {
|
||||
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
|
||||
// nothing to do
|
||||
return cute::make_tuple();
|
||||
} else if constexpr (ModeHasScales) {
|
||||
auto smem_tiled_copy_S = make_tiled_copy_A(SmemCopyAtomScale{}, tiled_mma);
|
||||
auto smem_thr_copy_S = smem_tiled_copy_S.get_thread_slice(warp_group_thread_idx);
|
||||
Tensor tCrS_copy_view = smem_thr_copy_S.retile_D(cute::get<1>(partitioned_extra_info)); // (CPY,CPY_M,CPY_K)
|
||||
|
||||
if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
|
||||
return cute::make_tuple(smem_tiled_copy_S, tCrS_copy_view);
|
||||
} else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
|
||||
Tensor tCrZ_copy_view = smem_thr_copy_S.retile_D(cute::get<3>(partitioned_extra_info)); // (CPY,CPY_M,CPY_K)
|
||||
return cute::make_tuple(smem_tiled_copy_S, tCrS_copy_view, tCrZ_copy_view);
|
||||
} else {
|
||||
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in A -> RF path.");
|
||||
}
|
||||
} else {
|
||||
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in A -> RF path.");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace cutlass::gemm::collective::detail
|
||||
@@ -0,0 +1,309 @@
|
||||
/* Copyright 2025 SGLang Team. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// Adapted from
|
||||
// https://github.com/NVIDIA/TensorRT-LLM/blob/be1788106245496872d18e702978e59b6bfd50e0/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/arch/memory.h>
|
||||
#include <cutlass/numeric_conversion.h>
|
||||
|
||||
namespace cutlass {
|
||||
namespace epilogue {
|
||||
namespace threadblock {
|
||||
|
||||
template <
|
||||
typename ThreadblockShape_,
|
||||
int ThreadCount,
|
||||
typename ScaleTileIterator_,
|
||||
typename OutputTileIterator_,
|
||||
typename ElementAccumulator_,
|
||||
typename ElementCompute_,
|
||||
typename ElementwiseFunctor_,
|
||||
bool UseMasking_ = false>
|
||||
class EpilogueVisitorPerRowPerCol {
|
||||
public:
|
||||
using ThreadblockShape = ThreadblockShape_;
|
||||
static int const kThreadCount = ThreadCount;
|
||||
|
||||
using ScaleTileIterator = ScaleTileIterator_;
|
||||
using OutputTileIterator = OutputTileIterator_;
|
||||
using ElementwiseFunctor = ElementwiseFunctor_;
|
||||
|
||||
static int const kIterations = OutputTileIterator::kIterations;
|
||||
static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
|
||||
|
||||
using ElementOutput = typename OutputTileIterator::Element;
|
||||
using LayoutOutput = cutlass::layout::RowMajor;
|
||||
using ElementAccumulator = ElementAccumulator_;
|
||||
|
||||
using AlphaScaleElementType = typename ScaleTileIterator::Element;
|
||||
|
||||
using ElementCompute = ElementCompute_;
|
||||
using AccumulatorFragment = Array<ElementAccumulator, kElementsPerAccess>;
|
||||
using ComputeFragment = Array<ElementCompute_, kElementsPerAccess>;
|
||||
using OutputVector = Array<ElementOutput, kElementsPerAccess>;
|
||||
|
||||
static int const kThreadsPerRow = OutputTileIterator::ThreadMap::Detail::kAccessWidth;
|
||||
static bool const kHasMultiStepsInRow = (OutputTileIterator::ThreadMap::Iterations::kColumn > 1);
|
||||
|
||||
/// Argument structure
|
||||
struct Arguments {
|
||||
typename ElementwiseFunctor::Params elementwise;
|
||||
int64_t batch_stride_alpha;
|
||||
int64_t batch_stride_C;
|
||||
int64_t batch_stride_D;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
Arguments() : batch_stride_alpha(0), batch_stride_C(0), batch_stride_D(0) {}
|
||||
|
||||
Arguments(typename ElementwiseFunctor::Params elementwise_)
|
||||
: elementwise(elementwise_), batch_stride_alpha(0), batch_stride_C(0), batch_stride_D(0) {}
|
||||
|
||||
Arguments(
|
||||
typename ElementwiseFunctor::Params elementwise_,
|
||||
int64_t batch_stride_alpha_,
|
||||
int64_t batch_stride_C_,
|
||||
int64_t batch_stride_D_)
|
||||
: elementwise(elementwise_),
|
||||
batch_stride_alpha(batch_stride_alpha_),
|
||||
batch_stride_C(batch_stride_C_),
|
||||
batch_stride_D(batch_stride_D_) {}
|
||||
};
|
||||
|
||||
struct Params {
|
||||
typename ElementwiseFunctor::Params elementwise;
|
||||
int64_t batch_stride_alpha;
|
||||
int64_t batch_stride_C;
|
||||
int64_t batch_stride_D;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params() {}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Arguments const& args)
|
||||
: elementwise(args.elementwise),
|
||||
batch_stride_alpha(args.batch_stride_alpha),
|
||||
batch_stride_C(args.batch_stride_C),
|
||||
batch_stride_D(args.batch_stride_D) {}
|
||||
};
|
||||
|
||||
/// Shared storage
|
||||
// Empty: this visitor declares no shared-memory members of its own.
struct SharedStorage {};
|
||||
|
||||
private:
|
||||
Params const& params_;
|
||||
SharedStorage& shared_storage_;
|
||||
MatrixCoord extent_;
|
||||
MatrixCoord extent_real_;
|
||||
ElementwiseFunctor elementwise_;
|
||||
|
||||
bool const with_bias_;
|
||||
bool const per_token_quant_;
|
||||
bool const per_channel_quant_;
|
||||
|
||||
AlphaScaleElementType* ptr_alpha_row_;
|
||||
AlphaScaleElementType* ptr_alpha_col_;
|
||||
ScaleTileIterator iterator_alpha_col_;
|
||||
OutputTileIterator iterator_C_;
|
||||
OutputTileIterator iterator_D_;
|
||||
|
||||
AlphaScaleElementType element_alpha_row_ = 1.0f;
|
||||
AlphaScaleElementType element_alpha_col_ = 1.0f;
|
||||
typename ScaleTileIterator::Fragment fragment_alpha_col_;
|
||||
typename OutputTileIterator::Fragment fragment_C_;
|
||||
typename OutputTileIterator::Fragment fragment_D_;
|
||||
|
||||
ElementAccumulator beta_;
|
||||
|
||||
int column_offset_;
|
||||
|
||||
MatrixCoord thread_offset_;
|
||||
|
||||
public:
|
||||
// Builds the visitor for one threadblock tile.
//
// Stores the quantization flags and scale pointers, and constructs the
// alpha-column, C and D tile iterators over problem_size at threadblock_offset
// for thread_idx. When a scale is not per-channel / per-token and its pointer
// is non-null, the single scalar is loaded here; otherwise the corresponding
// element keeps its default of 1.0f.
//
// warp_idx and lane_idx are accepted for interface compatibility and are not
// used. Members are initialized in declaration order; column_offset_ now
// receives the column_offset argument instead of being left uninitialized.
// NOTE(review): beta_ and thread_offset_ are still not initialized here —
// confirm they are assigned before use in the rest of the class.
CUTLASS_DEVICE
EpilogueVisitorPerRowPerCol(
    Params const& params,
    SharedStorage& shared_storage,
    cutlass::MatrixCoord const& problem_size,
    int thread_idx,
    int warp_idx,
    int lane_idx,
    typename ScaleTileIterator::Params params_alpha_col,
    typename OutputTileIterator::Params params_C,
    typename OutputTileIterator::Params params_D,
    bool with_bias,
    bool per_token_quant,
    bool per_channel_quant,
    AlphaScaleElementType* ptr_alpha_row,
    AlphaScaleElementType* ptr_alpha_col,
    typename OutputTileIterator::Element* ptr_C,
    typename OutputTileIterator::Element* ptr_D,
    cutlass::MatrixCoord const& threadblock_offset = cutlass::MatrixCoord(0, 0),
    int column_offset = 0,
    cutlass::MatrixCoord const& problem_size_real = cutlass::MatrixCoord(0, 0))
    : params_(params),
      shared_storage_(shared_storage),
      extent_(problem_size),
      extent_real_(problem_size_real),
      elementwise_(params.elementwise),
      with_bias_(with_bias),
      per_token_quant_(per_token_quant),
      per_channel_quant_(per_channel_quant),
      ptr_alpha_row_(ptr_alpha_row),
      ptr_alpha_col_(ptr_alpha_col),
      iterator_alpha_col_(params_alpha_col, ptr_alpha_col, problem_size, thread_idx, threadblock_offset),
      iterator_C_(params_C, ptr_C, problem_size, thread_idx, threadblock_offset),
      iterator_D_(params_D, ptr_D, problem_size, thread_idx, threadblock_offset),
      column_offset_(column_offset) {
  // Per-tensor column scale: one scalar shared by every column.
  if (!per_channel_quant_ && (ptr_alpha_col_ != nullptr)) {
    element_alpha_col_ = *ptr_alpha_col_;
  }

  // Per-tensor row scale: one scalar shared by every row.
  if (!per_token_quant_ && (ptr_alpha_row_ != nullptr)) {
    element_alpha_row_ = *ptr_alpha_row_;
  }
}
|
||||
|
||||
/// Helper to indicate split-K behavior
|
||||
CUTLASS_DEVICE
|
||||
void set_k_partition(
|
||||
int split_k_index, ///< Index of this threadblock within split-K partitioned scheme
|
||||
int split_k_slices) { ///< Total number of split-K slices
|
||||
}
|
||||
|
||||
/// Called to set the batch index
|
||||
CUTLASS_DEVICE
|
||||
void set_batch_index(int batch_idx) {
|
||||
iterator_alpha_col_.add_pointer_offset(batch_idx * params_.batch_stride_alpha);
|
||||
iterator_C_.add_pointer_offset(batch_idx * params_.batch_stride_C);
|
||||
iterator_D_.add_pointer_offset(batch_idx * params_.batch_stride_D);
|
||||
}
|
||||
|
||||
/// Called at the start of the epilogue just before iterating over accumulator slices
|
||||
CUTLASS_DEVICE
|
||||
void begin_epilogue() {
|
||||
if (per_channel_quant_) {
|
||||
iterator_alpha_col_.load(fragment_alpha_col_);
|
||||
}
|
||||
|
||||
if (with_bias_) {
|
||||
iterator_C_.load(fragment_C_);
|
||||
}
|
||||
}
|
||||
|
||||
/// Called at the start of one step before starting accumulator exchange
|
||||
CUTLASS_DEVICE
|
||||
void begin_step(int step_idx) {
|
||||
fragment_D_.clear();
|
||||
}
|
||||
|
||||
/// Called at the start of a row
|
||||
CUTLASS_DEVICE
|
||||
void begin_row(int row_idx) {
|
||||
// load alpha_row in begin_step only when per token(row) scaling is used
|
||||
if (per_token_quant_) {
|
||||
int thread_offset_row =
|
||||
iterator_D_.thread_start_row() + OutputTileIterator::ThreadMap::iteration_offset(row_idx).row();
|
||||
|
||||
arch::global_load<AlphaScaleElementType, sizeof(AlphaScaleElementType)>(
|
||||
element_alpha_row_, ptr_alpha_row_ + thread_offset_row, thread_offset_row < extent_.row());
|
||||
}
|
||||
}
|
||||
|
||||
/// Called after accumulators have been exchanged for each accumulator vector
|
||||
CUTLASS_DEVICE
|
||||
void visit(int iter_idx, int row_idx, int column_idx, int frag_idx, AccumulatorFragment const& accum) {
|
||||
NumericArrayConverter<ElementCompute, ElementAccumulator, kElementsPerAccess> source_converter;
|
||||
|
||||
ComputeFragment result = source_converter(accum);
|
||||
if (per_channel_quant_) {
|
||||
ComputeFragment alpha_col = reinterpret_cast<ComputeFragment*>(&fragment_alpha_col_)[column_idx];
|
||||
result = per_token_channel_scale_accumulator_(result, alpha_col, element_alpha_row_);
|
||||
} else {
|
||||
result = per_token_scale_accumulator_(result, element_alpha_col_, element_alpha_row_);
|
||||
}
|
||||
|
||||
if (with_bias_) {
|
||||
NumericArrayConverter<ElementCompute, ElementOutput, kElementsPerAccess> bias_converter;
|
||||
OutputVector bias = reinterpret_cast<OutputVector*>(&fragment_C_)[column_idx];
|
||||
result = bias_accumulator_(result, bias_converter(bias));
|
||||
}
|
||||
|
||||
// Convert to the output
|
||||
NumericArrayConverter<ElementOutput, ElementCompute, kElementsPerAccess> output_converter;
|
||||
OutputVector& output = reinterpret_cast<OutputVector*>(&fragment_D_)[frag_idx];
|
||||
output = output_converter(result);
|
||||
}
|
||||
|
||||
/// Called at the end of a row
|
||||
CUTLASS_DEVICE
|
||||
void end_row(int row_idx) {}
|
||||
|
||||
/// Called after all accumulator elements have been visited
|
||||
CUTLASS_DEVICE
|
||||
void end_step(int step_idx) {
|
||||
iterator_D_.store(fragment_D_);
|
||||
++iterator_D_;
|
||||
}
|
||||
|
||||
/// Called after all steps have been completed
|
||||
CUTLASS_DEVICE
|
||||
void end_epilogue() {}
|
||||
|
||||
private:
|
||||
/// Element-wise scaling with a per-element column scale and one scalar row
/// scale: out[i] = accum[i] * (scale_col[i] * scale_row).
CUTLASS_DEVICE
ComputeFragment per_token_channel_scale_accumulator_(
    ComputeFragment const& accum, ComputeFragment const& scale_col, AlphaScaleElementType const& scale_row) {
  ComputeFragment scaled;

  CUTLASS_PRAGMA_UNROLL
  for (int idx = 0; idx < ComputeFragment::kElements; ++idx) {
    // The two scales are multiplied first, then applied to the accumulator.
    scaled[idx] = accum[idx] * (scale_col[idx] * scale_row);
  }

  return scaled;
}
|
||||
|
||||
/// Element-wise scaling with scalar column and row scales:
/// out[i] = accum[i] * (scale_col * scale_row).
CUTLASS_DEVICE
ComputeFragment per_token_scale_accumulator_(
    ComputeFragment const& accum, AlphaScaleElementType const& scale_col, AlphaScaleElementType const& scale_row) {
  // The combined scale is the same for every element, so name it once.
  auto const combined_scale = scale_col * scale_row;

  ComputeFragment scaled;

  CUTLASS_PRAGMA_UNROLL
  for (int idx = 0; idx < ComputeFragment::kElements; ++idx) {
    scaled[idx] = accum[idx] * combined_scale;
  }

  return scaled;
}
|
||||
|
||||
/// Element-wise bias add: out[i] = accum[i] + bias[i].
///
/// \param accum  scaled accumulator values
/// \param bias   bias values already converted to the compute type
/// \return       the element-wise sum
CUTLASS_DEVICE
ComputeFragment bias_accumulator_(ComputeFragment const& accum, ComputeFragment const& bias) {
  ComputeFragment result;

  // Every object indexed below is a ComputeFragment, so the trip count comes
  // from that type (as in the two scaling helpers) and not from OutputVector.
  CUTLASS_PRAGMA_UNROLL
  for (int i = 0; i < ComputeFragment::kElements; ++i) {
    result[i] = accum[i] + bias[i];
  }

  return result;
}
|
||||
};
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace epilogue
|
||||
} // namespace cutlass
|
||||
@@ -0,0 +1,278 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "cute/arch/cluster_sm90.hpp"
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/gemm/collective/builders/sm90_common.inl"
|
||||
#include "cutlass/gemm/collective/collective_builder_decl.hpp"
|
||||
#include "cutlass/gemm/collective/collective_mma_decl.hpp"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/pipeline/sm90_pipeline.hpp"
|
||||
|
||||
// SM90 Collective Builders should be used only starting CUDA 12.0
|
||||
#if (__CUDACC_VER_MAJOR__ >= 12)
|
||||
#define CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
|
||||
#endif
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::gemm::collective {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// GMMA_TMA_WS_RS
|
||||
template <
|
||||
class ElementA_,
|
||||
class GmemLayoutATag_,
|
||||
int AlignmentA,
|
||||
class ElementB_,
|
||||
class GmemLayoutBTag_,
|
||||
int AlignmentB,
|
||||
class ElementAccumulator,
|
||||
class TileShape_MNK,
|
||||
class ClusterShape_MNK,
|
||||
class StageCountType,
|
||||
class KernelScheduleType>
|
||||
struct CollectiveBuilderMixedInput<
|
||||
arch::Sm90,
|
||||
arch::OpClassTensorOp,
|
||||
ElementA_,
|
||||
GmemLayoutATag_,
|
||||
AlignmentA,
|
||||
ElementB_,
|
||||
GmemLayoutBTag_,
|
||||
AlignmentB,
|
||||
ElementAccumulator,
|
||||
TileShape_MNK,
|
||||
ClusterShape_MNK,
|
||||
StageCountType,
|
||||
KernelScheduleType,
|
||||
cute::enable_if_t<
|
||||
(cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecialized> ||
|
||||
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedPingpong> ||
|
||||
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperative> ||
|
||||
cute::is_same_v<KernelScheduleType, KernelPtrArrayTmaWarpSpecializedCooperative> ||
|
||||
cute::is_same_v<KernelScheduleType, KernelPtrArrayTmaWarpSpecializedPingpong>) &&
|
||||
(detail::is_use_rmem_A<ElementA_, GmemLayoutATag_, ElementB_, GmemLayoutBTag_>() ||
|
||||
// ConvertAndScale and ConvertAndScaleWithZero
|
||||
cute::is_tuple<ElementA_>::value || cute::is_tuple<ElementB_>::value ||
|
||||
// DirectConvert
|
||||
sizeof_bits<ElementA_>::value != sizeof_bits<ElementB_>::value)>> {
|
||||
private:
|
||||
using ScaleA = detail::deduce_mixed_width_dtype_t<1, ElementA_>;
|
||||
using ScaleB = detail::deduce_mixed_width_dtype_t<1, ElementB_>;
|
||||
using ZeroA = detail::deduce_mixed_width_dtype_t<2, ElementA_>;
|
||||
using ZeroB = detail::deduce_mixed_width_dtype_t<2, ElementB_>;
|
||||
static constexpr bool NeitherIsTuple = !cute::is_tuple<ElementA_>::value && !cute::is_tuple<ElementB_>::value;
|
||||
// Determine if mixed input types.
|
||||
static constexpr bool IsMixedInput = cute::sizeof_bits_v<detail::deduce_mixed_width_dtype_t<0, ElementA_>> !=
|
||||
cute::sizeof_bits_v<detail::deduce_mixed_width_dtype_t<0, ElementB_>>;
|
||||
static constexpr bool IsArrayOfPointersGemm = cute::is_any_of_v<
|
||||
KernelScheduleType,
|
||||
KernelPtrArrayTmaWarpSpecializedCooperative,
|
||||
KernelPtrArrayTmaWarpSpecializedPingpong>;
|
||||
static_assert(IsMixedInput || !IsArrayOfPointersGemm, "Only mixed input grouped RS GEMM is supported.");
|
||||
|
||||
public:
|
||||
using ElementA = detail::deduce_mixed_width_dtype_t<0, ElementA_>;
|
||||
using ElementB = detail::deduce_mixed_width_dtype_t<0, ElementB_>;
|
||||
|
||||
static_assert(
|
||||
!IsMixedInput || (cute::is_tuple<ElementA_>::value ^ cute::is_tuple<ElementB_>::value ||
|
||||
(NeitherIsTuple && (sizeof_bits<ElementA>::value != sizeof_bits<ElementB>::value))),
|
||||
"Either A OR B must be a tuple or the widths of A and B must be different.");
|
||||
|
||||
static constexpr bool IsANarrow = sizeof_bits<ElementA>::value < sizeof_bits<ElementB>::value;
|
||||
|
||||
template <class T>
|
||||
static auto get_stride(T const& t) {
|
||||
if constexpr (not cute::is_layout<cute::remove_pointer_t<T>>::value) {
|
||||
return t;
|
||||
} else {
|
||||
if constexpr (cute::is_pointer_v<T>) {
|
||||
return &cute::stride(*t);
|
||||
} else {
|
||||
return cute::stride(t);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
using GmemLayoutATag = decltype(get_stride(GmemLayoutATag_{}));
|
||||
using GmemLayoutBTag = decltype(get_stride(GmemLayoutBTag_{}));
|
||||
|
||||
using ElementPairA =
|
||||
cute::conditional_t<IsMixedInput && IsANarrow && NeitherIsTuple, cute::tuple<ElementA>, ElementA_>;
|
||||
using ElementPairB =
|
||||
cute::conditional_t<IsMixedInput && !IsANarrow && NeitherIsTuple, cute::tuple<ElementB>, ElementB_>;
|
||||
|
||||
static constexpr bool IsATransformed = cute::is_tuple<ElementPairA>::value;
|
||||
using ElementScale = cute::conditional_t<IsATransformed, ScaleA, ScaleB>;
|
||||
using ElementZero = cute::conditional_t<IsATransformed, ZeroA, ZeroB>;
|
||||
|
||||
static_assert(is_static<TileShape_MNK>::value);
|
||||
static_assert(is_static<ClusterShape_MNK>::value);
|
||||
static_assert(
|
||||
detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>(),
|
||||
"Should meet TMA alignment requirement\n");
|
||||
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
|
||||
static_assert(cutlass::detail::dependent_false<ElementA>, "Unsupported Toolkit for SM90 Collective Builder\n");
|
||||
#endif
|
||||
static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_rs_tag_to_major_A<GmemLayoutATag>();
|
||||
static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_rs_tag_to_major_B<GmemLayoutBTag>();
|
||||
// If A is scaled, then we don't need to swap. Otherwise, we must ensure B goes to rmem and we must swap the
|
||||
// operands.
|
||||
static constexpr bool SwapAB =
|
||||
IsMixedInput ? !IsATransformed : detail::is_swapAB<ElementA, GmemLayoutATag, ElementB, GmemLayoutBTag>();
|
||||
static constexpr bool IsWarpSpecializedTransposeB =
|
||||
detail::is_warpspecialized_transpose_B<ElementA, GmemLayoutATag, ElementB, GmemLayoutBTag, KernelScheduleType>();
|
||||
static_assert(!IsMixedInput || !IsWarpSpecializedTransposeB, "Mixed input GEMM does not support WS transpose B.");
|
||||
|
||||
// When we relax the above assertion, we must handle setting the tile mma GmmaMajorB correctly.
|
||||
static constexpr cute::GMMA::Major TiledMmaGmmaMajorB = SwapAB ? GmmaMajorA : GmmaMajorB;
|
||||
|
||||
// For fp32 types, map to tf32 MMA value type.
|
||||
using ElementAMma = cute::conditional_t<cute::is_same_v<ElementA, float>, tfloat32_t, ElementA>;
|
||||
using ElementBMma = cute::conditional_t<cute::is_same_v<ElementB, float>, tfloat32_t, ElementB>;
|
||||
|
||||
// Handle mixed dtypes and MMA.
|
||||
using RealElementA = cute::conditional_t<SwapAB, ElementBMma, ElementAMma>;
|
||||
using RealElementB = cute::conditional_t<SwapAB, ElementAMma, ElementBMma>;
|
||||
using RealElementAMma = cute::conditional_t<IsMixedInput, RealElementB, RealElementA>;
|
||||
// Always the same for element B.
|
||||
using RealElementBMma = RealElementB;
|
||||
|
||||
static_assert(
|
||||
!IsMixedInput || TiledMmaGmmaMajorB == GMMA::Major::K || sizeof_bits<RealElementB>::value == 16,
|
||||
"Mixed input GEMM does not support MN major layout except for 16bit");
|
||||
|
||||
using AtomLayoutMNK = cute::conditional_t<
|
||||
cute::is_any_of_v<
|
||||
KernelScheduleType,
|
||||
KernelTmaWarpSpecializedCooperative,
|
||||
KernelPtrArrayTmaWarpSpecializedCooperative>,
|
||||
Layout<Shape<_2, _1, _1>>,
|
||||
Layout<Shape<_1, _1, _1>>>;
|
||||
|
||||
using TiledMma = decltype(cute::make_tiled_mma(
|
||||
cute::GMMA::rs_op_selector<
|
||||
RealElementAMma,
|
||||
RealElementBMma,
|
||||
ElementAccumulator,
|
||||
TileShape_MNK,
|
||||
GMMA::Major::K,
|
||||
GMMA::Major::K>(),
|
||||
AtomLayoutMNK{}));
|
||||
|
||||
using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
|
||||
using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{})));
|
||||
|
||||
using SmemLayoutAtomA = decltype(detail::rs_smem_selector<
|
||||
GmmaMajorA,
|
||||
ElementAMma,
|
||||
decltype(cute::get<0>(TileShape_MNK{})),
|
||||
decltype(cute::get<2>(TileShape_MNK{})),
|
||||
IsWarpSpecializedTransposeB>());
|
||||
using SmemLayoutAtomB = decltype(detail::rs_smem_selector<
|
||||
GmmaMajorB,
|
||||
ElementBMma,
|
||||
decltype(cute::get<1>(TileShape_MNK{})),
|
||||
decltype(cute::get<2>(TileShape_MNK{})),
|
||||
IsWarpSpecializedTransposeB>());
|
||||
|
||||
static constexpr size_t SmemAlignmentA = cutlass::detail::alignment_for_swizzle(SmemLayoutAtomA{});
|
||||
static constexpr size_t SmemAlignmentB = cutlass::detail::alignment_for_swizzle(SmemLayoutAtomB{});
|
||||
static constexpr int SmemAlignment = static_cast<int>(cute::max(SmemAlignmentA, SmemAlignmentB));
|
||||
|
||||
// Handle mixed dtype array GEMM's size of tensor map storage.
|
||||
static constexpr size_t TensorMapStorage = sizeof(cute::TmaDescriptor) * size_t(IsMixedInput) * 4;
|
||||
static constexpr int KernelSmemCarveout = static_cast<int>(TensorMapStorage);
|
||||
static constexpr int Sm90ReducedSmemCapacityBytes = detail::sm90_smem_capacity_bytes - KernelSmemCarveout;
|
||||
|
||||
static constexpr int PipelineStages =
|
||||
IsMixedInput ? (IsArrayOfPointersGemm ? detail::compute_stage_count_or_override_single_affine_transformed_input<
|
||||
Sm90ReducedSmemCapacityBytes,
|
||||
RealElementA,
|
||||
RealElementB,
|
||||
ElementScale,
|
||||
ElementZero,
|
||||
TileShape_MNK,
|
||||
StageCountType::bytes,
|
||||
SmemAlignment>(StageCountType{})
|
||||
: detail::compute_stage_count_or_override_single_affine_transformed_input<
|
||||
detail::sm90_smem_capacity_bytes,
|
||||
RealElementA,
|
||||
RealElementB,
|
||||
ElementScale,
|
||||
ElementZero,
|
||||
TileShape_MNK,
|
||||
StageCountType::bytes,
|
||||
SmemAlignment>(StageCountType{}))
|
||||
: detail::compute_stage_count_or_override<
|
||||
detail::sm90_smem_capacity_bytes,
|
||||
ElementAMma,
|
||||
ElementBMma,
|
||||
TileShape_MNK,
|
||||
StageCountType::bytes,
|
||||
SmemAlignment>(StageCountType{});
|
||||
|
||||
using DispatchPolicy = cute::conditional_t<
|
||||
IsMixedInput,
|
||||
cute::conditional_t<
|
||||
IsArrayOfPointersGemm,
|
||||
MainloopSm90ArrayTmaGmmaWarpSpecializedMixedInput<PipelineStages, ClusterShape_MNK, KernelScheduleType>,
|
||||
MainloopSm90TmaGmmaRmemAWarpSpecializedMixedInput<PipelineStages, ClusterShape_MNK, KernelScheduleType>>,
|
||||
MainloopSm90TmaGmmaRmemAWarpSpecialized<PipelineStages, ClusterShape_MNK, KernelScheduleType>>;
|
||||
|
||||
using SmemCopyAtomA = cute::conditional_t<SwapAB, void, Copy_Atom<cute::AutoVectorizingCopy, ElementA>>;
|
||||
using SmemCopyAtomB = cute::conditional_t<SwapAB, Copy_Atom<cute::AutoVectorizingCopy, ElementB>, void>;
|
||||
|
||||
// We pack the scale data with the operand that will be optionally scaled and converted before MMA.
|
||||
using StrideA = cute::conditional_t<
|
||||
cute::is_layout<cute::remove_pointer_t<GmemLayoutATag_>>::value,
|
||||
GmemLayoutATag_,
|
||||
TagToStrideA_t<GmemLayoutATag>>;
|
||||
using StrideB = cute::conditional_t<
|
||||
cute::is_layout<cute::remove_pointer_t<GmemLayoutBTag_>>::value,
|
||||
GmemLayoutBTag_,
|
||||
TagToStrideB_t<GmemLayoutBTag>>;
|
||||
|
||||
using CollectiveOp = CollectiveMmaArrayMixedInput<
|
||||
DispatchPolicy,
|
||||
TileShape_MNK,
|
||||
ElementPairA,
|
||||
StrideA,
|
||||
ElementPairB,
|
||||
StrideB,
|
||||
TiledMma,
|
||||
GmemTiledCopyA,
|
||||
SmemLayoutAtomA,
|
||||
SmemCopyAtomA,
|
||||
cute::identity,
|
||||
GmemTiledCopyB,
|
||||
SmemLayoutAtomB,
|
||||
SmemCopyAtomB,
|
||||
cute::identity>;
|
||||
|
||||
static_assert(
|
||||
SmemAlignment == static_cast<int>(cute::max(CollectiveOp::SmemAlignmentA, CollectiveOp::SmemAlignmentB)));
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::gemm::collective
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -0,0 +1,52 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass_extensions/gemm/collective/collective_mma_array_mixed_input.hpp"
|
||||
|
||||
namespace cutlass::gemm::collective {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Primary template of the mixed-input collective builder.
//
// It is only reached when no specialization matches the given parameters, in
// which case instantiating it fails with the diagnostic below. The trailing
// `Enable` parameter is the slot specializations use for their enable_if
// constraint.
template <
    class ArchTag,
    class OpClass,
    class ElementA,
    class GmemLayoutA,
    int AlignmentA,
    class ElementB,
    class GmemLayoutB,
    int AlignmentB,
    class ElementAccumulator,
    class TileShape_MNK,
    class ClusterShape_MNK,
    class StageCountType,
    class KernelScheduleType,
    class Enable = void>
struct CollectiveBuilderMixedInput {
  // The condition depends on ElementA so that it is evaluated only when this
  // primary template is actually instantiated.
  static_assert(sizeof(ElementA) == 0, "Could not build a collective for given parameters.");
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::gemm::collective
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#include "cutlass_extensions/gemm/collective/builders/sm90_gmma_builder_mixed_input.inl"
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -0,0 +1,53 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/detail/dependent_false.hpp"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::gemm::collective {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
class DispatchPolicy,
|
||||
class TileShape,
|
||||
class ElementA,
|
||||
class StrideA,
|
||||
class ElementB,
|
||||
class StrideB,
|
||||
class TiledMma,
|
||||
class GmemTiledCopyA,
|
||||
class SmemLayoutAtomA,
|
||||
class SmemCopyAtomA,
|
||||
class TransformA,
|
||||
class GmemTiledCopyB,
|
||||
class SmemLayoutAtomB,
|
||||
class SmemCopyAtomB,
|
||||
class TransformB>
|
||||
struct CollectiveMmaArrayMixedInput {
|
||||
static_assert(cutlass::detail::dependent_false<ElementA>, "Could not find a mainloop specialization.");
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::gemm::collective
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#include "cutlass_extensions/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input_.hpp"
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,62 @@
|
||||
// Adapted from
|
||||
// https://github.com/vllm-project/vllm/blob/main/csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh
|
||||
|
||||
#pragma once
|
||||
|
||||
// clang-format will break include orders
|
||||
// clang-format off
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
|
||||
// clang-format on
|
||||
|
||||
/**
 * Helper macro for checking CUTLASS errors.
 *
 * Evaluates `status` exactly once and raises through TORCH_CHECK, with the
 * CUTLASS status string as the message, unless it is
 * cutlass::Status::kSuccess.
 *
 * The expansion is a plain brace block, so do not use it as the unbraced body
 * of an `if` that has an `else`. It also declares a local named `error`, so
 * the argument expression must not refer to a variable of that name.
 */
#define CUTLASS_CHECK(status)                                                       \
  {                                                                                 \
    cutlass::Status error = status;                                                 \
    TORCH_CHECK(error == cutlass::Status::kSuccess, cutlassGetStatusString(error)); \
  }
|
||||
|
||||
template <typename GemmKernel>
|
||||
void cutlass_gemm_caller(
|
||||
torch::Device device,
|
||||
cute::Shape<int, int, int, int> prob_shape,
|
||||
typename GemmKernel::MainloopArguments mainloop_args,
|
||||
typename GemmKernel::EpilogueArguments epilogue_args,
|
||||
typename GemmKernel::TileSchedulerArguments scheduler = {}) {
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
hw_info.device_id = c10::cuda::current_device();
|
||||
hw_info.sm_count = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
|
||||
typename GemmKernel::Arguments args{
|
||||
cutlass::gemm::GemmUniversalMode::kGemm, prob_shape, mainloop_args, epilogue_args, hw_info, scheduler};
|
||||
|
||||
// Launch the CUTLASS GEMM kernel.
|
||||
using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
GemmOp gemm_op;
|
||||
CUTLASS_CHECK(gemm_op.can_implement(args));
|
||||
|
||||
size_t workspace_size = gemm_op.get_workspace_size(args);
|
||||
auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(device);
|
||||
auto workspace = torch::empty(workspace_size, workspace_options);
|
||||
|
||||
auto stream = at::cuda::getCurrentCUDAStream(device.index());
|
||||
|
||||
cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
38
sgl-kernel/csrc/cutlass_extensions/gemm/dispatch_policy.hpp
Normal file
38
sgl-kernel/csrc/cutlass_extensions/gemm/dispatch_policy.hpp
Normal file
@@ -0,0 +1,38 @@
|
||||
// Adapted from https://github.com/vllm-project/vllm/blob/main/csrc/cutlass_extensions/gemm/dispatch_policy.hpp
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
|
||||
namespace cutlass::gemm {
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// FP8 related policies (including Blocked Scaled Accumulation)
|
||||
// `ScaleGranularityM` specifies scaling granularity along M, while zero-value
|
||||
// `ScaleGranularityM` indicates that scaling granularity is
|
||||
// `size<0>(TileShape_MNK{})` along M.
|
||||
template <int ScaleGranularityM = 0>
|
||||
struct KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum : KernelTmaWarpSpecializedCooperative {};
|
||||
|
||||
// n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, Warp
|
||||
// specialized dynamic schedule For FP8 kernels with Block Scaling
|
||||
template <
|
||||
int Stages_,
|
||||
class ClusterShape_ = Shape<_1, _1, _1>,
|
||||
class KernelSchedule = KernelTmaWarpSpecialized,
|
||||
int ScaleGranularityM = 0 // `ScaleGranularityM` specifies scaling granularity along M,
|
||||
// while zero-value `ScaleGranularityM` indicates that scaling
|
||||
// granularity is `size<0>(TileShape_MNK{})` along M.
|
||||
>
|
||||
struct MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8
|
||||
: MainloopSm90TmaGmmaWarpSpecialized<Stages_, ClusterShape_, KernelSchedule> {
|
||||
static_assert(
|
||||
cute::
|
||||
is_same_v<KernelSchedule, KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>>,
|
||||
"KernelSchedule must be one of the warp specialized policies");
|
||||
};
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::gemm
|
||||
@@ -0,0 +1,197 @@
|
||||
// Adapted from
|
||||
// https://github.com/vllm-project/vllm/blob/main/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh
|
||||
#pragma once
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
#include "cutlass_extensions/gemm/cutlass_gemm_caller.cuh"
|
||||
#include "cutlass_extensions/gemm/dispatch_policy.hpp"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template <
|
||||
typename SchedulerType,
|
||||
typename OutType,
|
||||
int GroupSizeM_,
|
||||
int GroupSizeN_,
|
||||
int GroupSizeK_,
|
||||
int TileSizeM_ = 128,
|
||||
class ClusterShape = Shape<_1, _2, _1>>
|
||||
struct cutlass_3x_gemm_fp8_blockwise {
|
||||
using GroupSizeM = Int<GroupSizeM_>;
|
||||
using GroupSizeN = Int<GroupSizeN_>;
|
||||
using GroupSizeK = Int<GroupSizeK_>;
|
||||
using TileSizeM = Int<TileSizeM_>;
|
||||
|
||||
static_assert(TileSizeM_ % GroupSizeM_ == 0, "TileSizeM must be a multiple of GroupSizeM");
|
||||
|
||||
using ElementAB = cutlass::float_e4m3_t;
|
||||
|
||||
// A matrix configuration
|
||||
using ElementA = ElementAB;
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
|
||||
|
||||
// B matrix configuration
|
||||
using ElementB = ElementAB;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
|
||||
|
||||
// C/D matrix configuration
|
||||
using ElementC = void;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<OutType>::value;
|
||||
|
||||
using ElementD = OutType;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
static constexpr int AlignmentD = AlignmentC;
|
||||
|
||||
using ScaleTileShape = Shape<_1, _128, _128>;
|
||||
using ScaleConfig = decltype(cutlass::detail::sm90_trivial_blockwise_scale_config(ScaleTileShape{}));
|
||||
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
|
||||
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
|
||||
|
||||
// Multiply-accumulate blocking/pipelining details
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
using ElementCompute = float; // Element type for compute
|
||||
using TileShape = Shape<TileSizeM, GroupSizeN, GroupSizeK>; // Threadblock-level tile size
|
||||
|
||||
using ArchTag = cutlass::arch::Sm90;
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp;
|
||||
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
|
||||
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using StoreEpilogueCompute = typename cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90AccFetch>;
|
||||
|
||||
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum;
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OperatorClass,
|
||||
TileShape,
|
||||
ClusterShape,
|
||||
EpilogueTileType,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
AlignmentC,
|
||||
ElementD,
|
||||
LayoutD,
|
||||
AlignmentD,
|
||||
EpilogueSchedule,
|
||||
StoreEpilogueCompute>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OperatorClass,
|
||||
ElementA,
|
||||
cute::tuple<LayoutA, LayoutSFA>,
|
||||
AlignmentA,
|
||||
ElementB,
|
||||
cute::tuple<LayoutB, LayoutSFB>,
|
||||
AlignmentB,
|
||||
ElementAccumulator,
|
||||
TileShape,
|
||||
ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
|
||||
sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
KernelSchedule>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int, int, int, int>, // Indicates ProblemShape
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
SchedulerType>;
|
||||
};
|
||||
|
||||
template <typename Gemm>
|
||||
void cutlass_gemm_caller_blockwise(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
using GemmKernel = typename Gemm::GemmKernel;
|
||||
using ElementAB = typename Gemm::ElementAB;
|
||||
using ElementA = ElementAB;
|
||||
using ElementB = ElementAB;
|
||||
using ElementD = typename Gemm::ElementD;
|
||||
using ElementBlockScale = float;
|
||||
|
||||
using ScaleTileShape = Shape<_1, _128, _128>;
|
||||
using ScaleConfig = decltype(cutlass::detail::sm90_trivial_blockwise_scale_config(ScaleTileShape{}));
|
||||
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
|
||||
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
|
||||
|
||||
int m = a.size(0);
|
||||
int k = a.size(1);
|
||||
int n = b.size(1);
|
||||
|
||||
auto a_ptr = static_cast<ElementA*>(a.data_ptr());
|
||||
auto b_ptr = static_cast<ElementB*>(b.data_ptr());
|
||||
|
||||
auto a_s_ptr = static_cast<ElementBlockScale*>(a_scales.data_ptr());
|
||||
auto b_s_ptr = static_cast<ElementBlockScale*>(b_scales.data_ptr());
|
||||
|
||||
using StrideA = typename GemmKernel::StrideA;
|
||||
using StrideB = typename GemmKernel::StrideB;
|
||||
using StrideD = typename GemmKernel::StrideD;
|
||||
using StrideC = typename GemmKernel::StrideC;
|
||||
|
||||
StrideA a_stride = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));
|
||||
StrideB b_stride = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1));
|
||||
StrideC c_stride = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, 1));
|
||||
LayoutSFA layout_sfa = ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, 1));
|
||||
LayoutSFB layout_sfb = ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1));
|
||||
|
||||
typename GemmKernel::MainloopArguments mainloop_args{
|
||||
a_ptr, a_stride, b_ptr, b_stride, a_s_ptr, layout_sfa, b_s_ptr, layout_sfb};
|
||||
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
|
||||
typename GemmKernel::EpilogueArguments epilogue_args{{}, c_ptr, c_stride, c_ptr, c_stride};
|
||||
|
||||
typename GemmKernel::TileSchedulerArguments scheduler;
|
||||
|
||||
static constexpr bool UsesStreamKScheduler =
|
||||
cute::is_same_v<typename GemmKernel::TileSchedulerTag, cutlass::gemm::StreamKScheduler>;
|
||||
|
||||
if constexpr (UsesStreamKScheduler) {
|
||||
using DecompositionMode =
|
||||
typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode;
|
||||
using ReductionMode =
|
||||
typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::ReductionMode;
|
||||
|
||||
scheduler.decomposition_mode = DecompositionMode::StreamK;
|
||||
scheduler.reduction_mode = ReductionMode::Nondeterministic;
|
||||
}
|
||||
|
||||
cutlass_gemm_caller<GemmKernel>(a.device(), {m, n, k, 1}, mainloop_args, epilogue_args, scheduler);
|
||||
}
|
||||
|
||||
template <typename OutType>
|
||||
void cutlass_gemm_blockwise_sm90_fp8_dispatch(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
auto k = a.size(1);
|
||||
auto n = b.size(1);
|
||||
|
||||
if (k > 3 * n) {
|
||||
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<cutlass::gemm::StreamKScheduler, OutType, 1, 128, 128>>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
} else {
|
||||
cutlass_gemm_caller_blockwise<
|
||||
cutlass_3x_gemm_fp8_blockwise<cutlass::gemm::PersistentScheduler, OutType, 1, 128, 128>>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,356 @@
|
||||
/* Copyright 2025 SGLang Team. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// Adapted from
|
||||
// https://github.com/NVIDIA/TensorRT-LLM/blob/be1788106245496872d18e702978e59b6bfd50e0/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/device/gemm_universal_base_compat.h
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/device_kernel.h>
|
||||
#include <cutlass/trace.h>
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace device {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/*
|
||||
This is the device layer from CUTLASS 2.10 (SHA - cc85b64cf676c45f98a17e3a47c0aafcf817f088)
|
||||
It is replicated here since we needed to duplicate kernel level APIs for mixed dtype GEMMs
|
||||
and SmoothQuant. The newer device layer is not compatible with these older kernel level APIs.
|
||||
|
||||
Note: While CUTLASS 3.x supports stream-k, none of the kernels in the extensions folder support
|
||||
that feature at the moment.
|
||||
*/
|
||||
|
||||
template <typename GemmKernel_>
|
||||
class GemmUniversalBaseCompat {
|
||||
public:
|
||||
using GemmKernel = GemmKernel_;
|
||||
using ThreadblockShape = typename GemmKernel::Mma::Shape;
|
||||
|
||||
using ElementA = typename GemmKernel::ElementA;
|
||||
using LayoutA = typename GemmKernel::LayoutA;
|
||||
using TensorRefA = TensorRef<ElementA const, LayoutA>;
|
||||
static ComplexTransform const kTransformA = GemmKernel::kTransformA;
|
||||
|
||||
using ElementB = typename GemmKernel::ElementB;
|
||||
using LayoutB = typename GemmKernel::LayoutB;
|
||||
using TensorRefB = TensorRef<ElementB const, LayoutB>;
|
||||
static ComplexTransform const kTransformB = GemmKernel::kTransformB;
|
||||
|
||||
using ElementC = typename GemmKernel::ElementC;
|
||||
using LayoutC = typename GemmKernel::LayoutC;
|
||||
using TensorRefC = TensorRef<ElementC const, LayoutC>;
|
||||
using TensorRefD = TensorRef<ElementC, LayoutC>;
|
||||
|
||||
using ElementAccumulator = typename GemmKernel::Mma::Policy::Operator::ElementC;
|
||||
|
||||
using EpilogueOutputOp = typename GemmKernel::EpilogueOutputOp;
|
||||
using ThreadblockSwizzle = typename GemmKernel::ThreadblockSwizzle;
|
||||
using Operator = typename GemmKernel::Operator;
|
||||
|
||||
/// Argument structure
|
||||
using Arguments = typename GemmKernel::Arguments;
|
||||
|
||||
protected:
|
||||
/// Kernel parameters object
|
||||
typename GemmKernel::Params params_;
|
||||
|
||||
protected:
|
||||
/// Private helper to obtain the grid dimensions with fix-up for split-K
|
||||
static void get_grid_shape_(gemm::GemmCoord& grid_tiled_shape, int& gemm_k_size, Arguments const& args) {
|
||||
// Determine grid shape
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
grid_tiled_shape = threadblock_swizzle.get_tiled_shape(
|
||||
args.problem_size, {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.batch_count);
|
||||
|
||||
gemm_k_size = args.problem_size.k();
|
||||
|
||||
if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) {
|
||||
int const kAlignK =
|
||||
const_max(const_max(128 / sizeof_bits<ElementA>::value, 128 / sizeof_bits<ElementB>::value), 1);
|
||||
|
||||
gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK);
|
||||
|
||||
if (gemm_k_size) {
|
||||
grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
/// Constructs the GEMM.
|
||||
GemmUniversalBaseCompat() {}
|
||||
|
||||
/// Determines whether the GEMM can execute the given problem.
|
||||
static Status can_implement(Arguments const& args) {
|
||||
// Determine grid shape
|
||||
cutlass::gemm::GemmCoord grid_tiled_shape;
|
||||
int gemm_k_size = 0;
|
||||
|
||||
get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
|
||||
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
dim3 grid = threadblock_swizzle.get_grid_shape(grid_tiled_shape);
|
||||
|
||||
uint32_t const kGridYZMax = ((1 << (sizeof(uint16_t) * 8)) - 1);
|
||||
|
||||
if (!(grid.y <= kGridYZMax && grid.z <= kGridYZMax)) {
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
|
||||
return GemmKernel::can_implement(args);
|
||||
}
|
||||
|
||||
/// Gets the workspace size
|
||||
static size_t get_workspace_size(Arguments const& args) {
|
||||
CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_workspace_size()");
|
||||
|
||||
size_t workspace_bytes = 0;
|
||||
|
||||
// Determine grid shape
|
||||
cutlass::gemm::GemmCoord grid_tiled_shape;
|
||||
int gemm_k_size = 0;
|
||||
|
||||
get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
|
||||
|
||||
if (args.mode == GemmUniversalMode::kGemmSplitKParallel) {
|
||||
// Split-K parallel always requires a temporary workspace
|
||||
workspace_bytes = sizeof(ElementC) * size_t(args.batch_stride_D) * size_t(grid_tiled_shape.k());
|
||||
} else if (args.mode == GemmUniversalMode::kGemm && grid_tiled_shape.k() > 1) {
|
||||
// Serial split-K only requires a temporary workspace if the number of partitions along the
|
||||
// GEMM K dimension is greater than one.
|
||||
workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n());
|
||||
}
|
||||
|
||||
CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes);
|
||||
|
||||
workspace_bytes += GemmKernel::get_extra_workspace_size(args, grid_tiled_shape);
|
||||
|
||||
return workspace_bytes;
|
||||
}
|
||||
|
||||
/// Computes the grid shape
|
||||
static dim3 get_grid_shape(Arguments const& args) {
|
||||
CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_grid_shape()");
|
||||
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
cutlass::gemm::GemmCoord grid_tiled_shape;
|
||||
int gemm_k_size = 0;
|
||||
|
||||
get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
|
||||
dim3 result = threadblock_swizzle.get_grid_shape(grid_tiled_shape);
|
||||
|
||||
CUTLASS_TRACE_HOST(
|
||||
" grid_tiled_shape: " << grid_tiled_shape << "\n"
|
||||
<< " result = {" << result << "}");
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/// Computes the maximum number of active blocks per multiprocessor
|
||||
static int maximum_active_blocks(int smem_capacity = -1) {
|
||||
CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::maximum_active_blocks()");
|
||||
|
||||
int max_active_blocks = -1;
|
||||
int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
|
||||
|
||||
CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes");
|
||||
|
||||
if (smem_size <= (48 << 10)) {
|
||||
cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&max_active_blocks, Kernel<GemmKernel>, GemmKernel::kThreadCount, smem_size);
|
||||
|
||||
if (result == cudaSuccess) {
|
||||
CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);
|
||||
return max_active_blocks;
|
||||
}
|
||||
} else {
|
||||
// Query assuming zero shared memory then compute occupancy limit based on SMEM
|
||||
cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&max_active_blocks, Kernel<GemmKernel>, GemmKernel::kThreadCount, 0);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
CUTLASS_TRACE_HOST(
|
||||
" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " << cudaGetErrorString(result));
|
||||
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (smem_capacity < 0) {
|
||||
int device_idx = 0;
|
||||
result = cudaGetDevice(&device_idx);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
cudaDeviceProp properties;
|
||||
result = cudaGetDeviceProperties(&properties, device_idx);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
smem_capacity = static_cast<int>(properties.sharedMemPerMultiprocessor);
|
||||
}
|
||||
|
||||
int occupancy = std::min(max_active_blocks, smem_capacity / smem_size);
|
||||
|
||||
CUTLASS_TRACE_HOST(" occupancy: " << occupancy);
|
||||
|
||||
return occupancy;
|
||||
}
|
||||
|
||||
CUTLASS_TRACE_HOST(" returning internal error");
|
||||
|
||||
return -1;
|
||||
}
|
||||
|
||||
/// Initializes GEMM state from arguments.
|
||||
Status initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
|
||||
CUTLASS_TRACE_HOST(
|
||||
"GemmUniversalBaseCompat::initialize() - workspace " << workspace
|
||||
<< ", stream: " << (stream ? "non-null" : "null"));
|
||||
|
||||
size_t workspace_bytes = get_workspace_size(args);
|
||||
|
||||
CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes);
|
||||
|
||||
if (workspace_bytes) {
|
||||
if (!workspace) {
|
||||
CUTLASS_TRACE_HOST(" error: device workspace must not be null");
|
||||
|
||||
return Status::kErrorWorkspaceNull;
|
||||
}
|
||||
|
||||
if (args.mode == GemmUniversalMode::kGemm) {
|
||||
CUTLASS_TRACE_HOST(" clearing device workspace");
|
||||
cudaError_t result = cudaMemsetAsync(workspace, 0, workspace_bytes, stream);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
CUTLASS_TRACE_HOST(" cudaMemsetAsync() returned error " << cudaGetErrorString(result));
|
||||
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get CUDA grid shape
|
||||
cutlass::gemm::GemmCoord grid_tiled_shape;
|
||||
int gemm_k_size = 0;
|
||||
|
||||
get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
|
||||
|
||||
// Initialize the Params structure
|
||||
params_ = typename GemmKernel::Params(args, grid_tiled_shape, gemm_k_size, static_cast<int*>(workspace));
|
||||
|
||||
// Specify shared memory capacity for kernel.
|
||||
int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
|
||||
|
||||
if (smem_size >= (48 << 10)) {
|
||||
cudaError_t result =
|
||||
cudaFuncSetAttribute(Kernel<GemmKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Lightweight update given a subset of arguments
|
||||
Status update(Arguments const& args, void* workspace = nullptr) {
|
||||
CUTLASS_TRACE_HOST("GemmUniversalBaseCompat()::update() - workspace: " << workspace);
|
||||
|
||||
size_t workspace_bytes = get_workspace_size(args);
|
||||
|
||||
if (workspace_bytes && !workspace) {
|
||||
return Status::kErrorWorkspaceNull;
|
||||
}
|
||||
|
||||
params_.update(args, workspace);
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Runs the kernel using initialized state.
|
||||
Status run(cudaStream_t stream = nullptr) {
|
||||
CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::run()");
|
||||
|
||||
//
|
||||
// Configure grid and block dimensions
|
||||
//
|
||||
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
|
||||
dim3 block(GemmKernel::kThreadCount, 1, 1);
|
||||
|
||||
int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
|
||||
|
||||
//
|
||||
// Launch kernel
|
||||
//
|
||||
|
||||
CUTLASS_TRACE_HOST(" grid: (" << grid << "), block: (" << block << "), SMEM: " << smem_size << " bytes");
|
||||
|
||||
// Launch
|
||||
cutlass::Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params_);
|
||||
|
||||
//
|
||||
// Query for errors
|
||||
//
|
||||
cudaError_t result = cudaGetLastError();
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result));
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Runs the kernel using initialized state.
|
||||
Status operator()(cudaStream_t stream = nullptr) {
|
||||
return run(stream);
|
||||
}
|
||||
|
||||
/// Runs the kernel using initialized state.
|
||||
Status operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
|
||||
Status status = initialize(args, workspace, stream);
|
||||
|
||||
if (status == Status::kSuccess) {
|
||||
status = run(stream);
|
||||
}
|
||||
|
||||
return status;
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace device
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -0,0 +1,492 @@
|
||||
/* Copyright 2025 SGLang Team. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// Adapted from
|
||||
// https://github.com/NVIDIA/TensorRT-LLM/blob/be1788106245496872d18e702978e59b6bfd50e0/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor.h
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/complex.h>
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/fast_math.h>
|
||||
#include <cutlass/matrix_coord.h>
|
||||
#include <cutlass/trace.h>
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace kernel {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
|
||||
typename Epilogue_, ///! Epilogue
|
||||
typename ThreadblockSwizzle_ ///! Threadblock swizzling function
|
||||
>
|
||||
struct GemmWithEpilogueVisitor {
|
||||
public:
|
||||
using Mma = Mma_;
|
||||
using Epilogue = Epilogue_;
|
||||
using EpilogueVisitor = typename Epilogue::Visitor;
|
||||
using ThreadblockSwizzle = ThreadblockSwizzle_;
|
||||
|
||||
using ElementA = typename Mma::IteratorA::Element;
|
||||
using LayoutA = typename Mma::IteratorA::Layout;
|
||||
using TensorRefA = TensorRef<ElementA, LayoutA>;
|
||||
|
||||
using ElementB = typename Mma::IteratorB::Element;
|
||||
using LayoutB = typename Mma::IteratorB::Layout;
|
||||
using TensorRefB = TensorRef<ElementB, LayoutB>;
|
||||
|
||||
using ElementCompute = typename EpilogueVisitor::ElementCompute;
|
||||
using LayoutAlphaCol = cutlass::layout::RowMajor;
|
||||
using LayoutAlphaRow = cutlass::layout::ColumnMajor;
|
||||
using TensorRefAlphaCol = TensorRef<ElementCompute, LayoutAlphaCol>;
|
||||
using TensorRefAlphaRow = TensorRef<ElementCompute, LayoutAlphaRow>;
|
||||
|
||||
using ElementC = typename EpilogueVisitor::ElementOutput;
|
||||
using LayoutC = typename Epilogue::Layout;
|
||||
using TensorRefC = TensorRef<ElementC, LayoutC>;
|
||||
|
||||
static ComplexTransform const kTransformA = Mma::kTransformA;
|
||||
static ComplexTransform const kTransformB = Mma::kTransformB;
|
||||
using Operator = typename Mma::Operator;
|
||||
|
||||
using OperatorClass = typename Mma::Operator::OperatorClass;
|
||||
using ThreadblockShape = typename Mma::Shape;
|
||||
using WarpShape = typename Mma::Operator::Shape;
|
||||
using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
|
||||
using ArchTag = typename Mma::ArchTag;
|
||||
using EpilogueOutputOp =
|
||||
typename Epilogue::Visitor::ElementwiseFunctor; // Define type so GemmUniversalBase doesn't complain
|
||||
|
||||
static int const kStages = Mma::kStages;
|
||||
static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
|
||||
static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
|
||||
static int const kAlignmentC = EpilogueVisitor::kElementsPerAccess;
|
||||
|
||||
/// Warp count (concept: GemmShape)
|
||||
using WarpCount = typename Mma::WarpCount;
|
||||
static int const kThreadCount = 32 * WarpCount::kCount;
|
||||
|
||||
/// Split-K preserves splits that are 128b aligned
|
||||
static int const kSplitKAlignment = const_max(128 / sizeof_bits<ElementA>::value, 128 / sizeof_bits<ElementB>::value);
|
||||
|
||||
//
|
||||
// Structures
|
||||
//
|
||||
|
||||
/// Argument structure
|
||||
struct Arguments {
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
GemmUniversalMode mode;
|
||||
GemmCoord problem_size;
|
||||
int batch_count;
|
||||
|
||||
TensorRefA ref_A;
|
||||
TensorRefB ref_B;
|
||||
TensorRefAlphaCol ref_alpha_col;
|
||||
TensorRefAlphaRow ref_alpha_row;
|
||||
TensorRefC ref_C;
|
||||
TensorRefC ref_D;
|
||||
|
||||
int64_t batch_stride_A;
|
||||
int64_t batch_stride_B;
|
||||
int64_t batch_stride_D;
|
||||
|
||||
typename EpilogueVisitor::Arguments epilogue_visitor;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
Arguments() : mode(GemmUniversalMode::kGemm), batch_count(1) {}
|
||||
|
||||
/// constructs an arguments structure
|
||||
Arguments(
|
||||
GemmCoord problem_size_,
|
||||
TensorRefA ref_A_,
|
||||
TensorRefB ref_B_,
|
||||
TensorRefAlphaCol ref_alpha_col_,
|
||||
TensorRefAlphaRow ref_alpha_row_,
|
||||
TensorRefC ref_C_,
|
||||
TensorRefC ref_D_,
|
||||
typename EpilogueVisitor::Arguments epilogue_visitor_)
|
||||
: mode(GemmUniversalMode::kGemm),
|
||||
problem_size(problem_size_),
|
||||
batch_count(1),
|
||||
ref_A(ref_A_),
|
||||
ref_B(ref_B_),
|
||||
ref_alpha_col(ref_alpha_col_),
|
||||
ref_alpha_row(ref_alpha_row_),
|
||||
ref_C(ref_C_),
|
||||
ref_D(ref_D_),
|
||||
batch_stride_A(0),
|
||||
batch_stride_B(0),
|
||||
batch_stride_D(0),
|
||||
epilogue_visitor(epilogue_visitor_) {}
|
||||
};
|
||||
|
||||
//
|
||||
// Structure for precomputing values in host memory and passing to kernels
|
||||
//
|
||||
|
||||
/// Parameters structure
|
||||
struct Params {
|
||||
cutlass::gemm::GemmCoord problem_size;
|
||||
cutlass::gemm::GemmCoord grid_tiled_shape;
|
||||
int swizzle_log_tile;
|
||||
|
||||
typename Mma::IteratorA::Params params_A;
|
||||
typename Mma::IteratorB::Params params_B;
|
||||
typename EpilogueVisitor::ScaleTileIterator::Params params_alpha_col;
|
||||
typename EpilogueVisitor::ScaleTileIterator::Params params_alpha_row;
|
||||
typename EpilogueVisitor::OutputTileIterator::Params params_C;
|
||||
typename EpilogueVisitor::OutputTileIterator::Params params_D;
|
||||
|
||||
GemmUniversalMode mode;
|
||||
int batch_count;
|
||||
int gemm_k_size;
|
||||
|
||||
void* ptr_A;
|
||||
void* ptr_B;
|
||||
typename EpilogueVisitor::ScaleTileIterator::Element* ptr_alpha_col;
|
||||
typename EpilogueVisitor::ScaleTileIterator::Element* ptr_alpha_row;
|
||||
ElementC* ptr_C;
|
||||
ElementC* ptr_D;
|
||||
|
||||
int64_t batch_stride_A;
|
||||
int64_t batch_stride_B;
|
||||
|
||||
typename EpilogueVisitor::Params epilogue_visitor;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params()
|
||||
: swizzle_log_tile(0),
|
||||
params_A(0),
|
||||
params_B(0),
|
||||
params_alpha_col(0),
|
||||
params_C(0),
|
||||
params_D(0),
|
||||
batch_count(0),
|
||||
gemm_k_size(0),
|
||||
mode(cutlass::gemm::GemmUniversalMode::kGemm),
|
||||
ptr_A(nullptr),
|
||||
ptr_B(nullptr),
|
||||
ptr_alpha_col(nullptr),
|
||||
ptr_alpha_row(nullptr),
|
||||
ptr_C(nullptr),
|
||||
ptr_D(nullptr),
|
||||
batch_stride_A(0),
|
||||
batch_stride_B(0) {}
|
||||
|
||||
Params(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape_, int gemm_k_size_, int* workspace_)
|
||||
: problem_size(args.problem_size),
|
||||
swizzle_log_tile(0),
|
||||
params_A(args.ref_A.layout()),
|
||||
params_B(args.ref_B.layout()),
|
||||
params_alpha_col(args.ref_alpha_col.layout()),
|
||||
params_alpha_row(args.ref_alpha_col.layout()),
|
||||
params_C(args.ref_C.layout()),
|
||||
params_D(args.ref_D.layout()),
|
||||
mode(args.mode),
|
||||
batch_count(args.batch_count),
|
||||
gemm_k_size(args.problem_size.k()),
|
||||
ptr_A(args.ref_A.data()),
|
||||
ptr_B(args.ref_B.data()),
|
||||
ptr_alpha_col(args.ref_alpha_col.data()),
|
||||
ptr_alpha_row(args.ref_alpha_row.data()),
|
||||
ptr_C(args.ref_C.data()),
|
||||
ptr_D(args.ref_D.data()),
|
||||
batch_stride_A(args.batch_stride_A),
|
||||
batch_stride_B(args.batch_stride_B),
|
||||
epilogue_visitor(args.epilogue_visitor) {
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
grid_tiled_shape = threadblock_swizzle.get_tiled_shape(
|
||||
args.problem_size, {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.batch_count);
|
||||
|
||||
if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) {
|
||||
int const kAlignK =
|
||||
const_max(const_max(128 / sizeof_bits<ElementA>::value, 128 / sizeof_bits<ElementB>::value), 1);
|
||||
|
||||
gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK);
|
||||
|
||||
if (gemm_k_size) {
|
||||
grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size);
|
||||
}
|
||||
}
|
||||
|
||||
swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape);
|
||||
}
|
||||
};
|
||||
|
||||
/// Shared memory storage structure
|
||||
union SharedStorage {
|
||||
typename Mma::SharedStorage main_loop;
|
||||
|
||||
struct {
|
||||
typename Epilogue::SharedStorage epilogue;
|
||||
typename EpilogueVisitor::SharedStorage visitor;
|
||||
} epilogue;
|
||||
};
|
||||
|
||||
public:
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Device-side constructor; the kernel object carries no state to initialize.
CUTLASS_DEVICE
GemmWithEpilogueVisitor() {
}
|
||||
|
||||
/// Determines whether kernel satisfies alignment
|
||||
/// Checks that the problem extents are divisible by each operand's access width.
///
/// For every operand, the extent that is tested depends on the operand's layout; layouts
/// not listed below are never reported as misaligned.
///
/// \param problem_size  GEMM extents (m, n, k).
/// \return kErrorMisalignedOperand for the first of A, B, C that fails, else kSuccess.
static Status can_implement(cutlass::gemm::GemmCoord const& problem_size) {
  CUTLASS_TRACE_HOST("GemmWithEpilogueVisitor::can_implement()");

  // Elements per vectorized access for each operand.
  int const access_a = Mma::IteratorA::AccessType::kElements;
  int const access_b = Mma::IteratorB::AccessType::kElements;
  int const access_c = EpilogueVisitor::OutputTileIterator::kElementsPerAccess;

  int const extent_m = problem_size.m();
  int const extent_n = problem_size.n();
  int const extent_k = problem_size.k();

  // A: column-major is checked along M; row-major and column-major-interleaved along K.
  bool a_misaligned = false;
  if (platform::is_same<LayoutA, layout::ColumnMajor>::value) {
    a_misaligned = (extent_m % access_a) != 0;
  } else if (
      platform::is_same<LayoutA, layout::RowMajor>::value ||
      platform::is_same<LayoutA, layout::ColumnMajorInterleaved<32>>::value ||
      platform::is_same<LayoutA, layout::ColumnMajorInterleaved<64>>::value) {
    a_misaligned = (extent_k % access_a) != 0;
  }

  // B: row-major is checked along N; column-major and row-major-interleaved along K.
  bool b_misaligned = false;
  if (platform::is_same<LayoutB, layout::RowMajor>::value) {
    b_misaligned = (extent_n % access_b) != 0;
  } else if (
      platform::is_same<LayoutB, layout::ColumnMajor>::value ||
      platform::is_same<LayoutB, layout::RowMajorInterleaved<32>>::value ||
      platform::is_same<LayoutB, layout::RowMajorInterleaved<64>>::value) {
    b_misaligned = (extent_k % access_b) != 0;
  }

  // C: column-major is checked along M; row-major and column-major-interleaved along N.
  bool c_misaligned = false;
  if (platform::is_same<LayoutC, layout::ColumnMajor>::value) {
    c_misaligned = (extent_m % access_c) != 0;
  } else if (
      platform::is_same<LayoutC, layout::RowMajor>::value ||
      platform::is_same<LayoutC, layout::ColumnMajorInterleaved<32>>::value ||
      platform::is_same<LayoutC, layout::ColumnMajorInterleaved<64>>::value) {
    c_misaligned = (extent_n % access_c) != 0;
  }

  if (a_misaligned) {
    CUTLASS_TRACE_HOST("  returning kErrorMisalignedOperand for A operand");
    return Status::kErrorMisalignedOperand;
  }

  if (b_misaligned) {
    CUTLASS_TRACE_HOST("  returning kErrorMisalignedOperand for B operand");
    return Status::kErrorMisalignedOperand;
  }

  if (c_misaligned) {
    CUTLASS_TRACE_HOST("  returning kErrorMisalignedOperand for C operand");
    return Status::kErrorMisalignedOperand;
  }

  CUTLASS_TRACE_HOST("  returning kSuccess");

  return Status::kSuccess;
}
|
||||
|
||||
/// Convenience overload: validates alignment using only `args.problem_size`.
static Status can_implement(Arguments const& args) {
  cutlass::gemm::GemmCoord const& extents = args.problem_size;
  return can_implement(extents);
}
|
||||
|
||||
/// Extra workspace bytes needed by this kernel beyond what the device layer allocates.
/// Always zero; both parameters are ignored.
static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape) {
  size_t const extra_bytes = 0;
  return extra_bytes;
}
|
||||
|
||||
#define SPLIT_K_ENABLED 1
|
||||
|
||||
/// Executes one GEMM
|
||||
CUTLASS_DEVICE
|
||||
void run_kernel_(Params const& params, SharedStorage& shared_storage) {
|
||||
// Compute threadblock location
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
|
||||
|
||||
// Early exit if CTA is out of range
|
||||
if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
|
||||
params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {
|
||||
return;
|
||||
}
|
||||
|
||||
int offset_k = 0;
|
||||
int problem_size_k = params.problem_size.k();
|
||||
|
||||
ElementA* ptr_A = static_cast<ElementA*>(params.ptr_A);
|
||||
ElementB* ptr_B = static_cast<ElementB*>(params.ptr_B);
|
||||
|
||||
#if SPLIT_K_ENABLED
|
||||
//
|
||||
// Fetch pointers based on mode.
|
||||
//
|
||||
if (params.mode == GemmUniversalMode::kGemm || params.mode == GemmUniversalMode::kGemmSplitKParallel) {
|
||||
if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) {
|
||||
problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size;
|
||||
}
|
||||
|
||||
offset_k = threadblock_tile_offset.k() * params.gemm_k_size;
|
||||
} else if (params.mode == GemmUniversalMode::kBatched) {
|
||||
ptr_A += threadblock_tile_offset.k() * params.batch_stride_A;
|
||||
ptr_B += threadblock_tile_offset.k() * params.batch_stride_B;
|
||||
} else if (params.mode == GemmUniversalMode::kArray) {
|
||||
ptr_A = static_cast<ElementA* const*>(params.ptr_A)[threadblock_tile_offset.k()];
|
||||
ptr_B = static_cast<ElementB* const*>(params.ptr_B)[threadblock_tile_offset.k()];
|
||||
}
|
||||
#endif
|
||||
|
||||
// Compute initial location in logical coordinates
|
||||
cutlass::MatrixCoord tb_offset_A{
|
||||
threadblock_tile_offset.m() * Mma::Shape::kM,
|
||||
offset_k,
|
||||
};
|
||||
|
||||
cutlass::MatrixCoord tb_offset_B{offset_k, threadblock_tile_offset.n() * Mma::Shape::kN};
|
||||
|
||||
// Compute position within threadblock
|
||||
int thread_idx = threadIdx.x;
|
||||
|
||||
// Construct iterators to A and B operands
|
||||
typename Mma::IteratorA iterator_A(
|
||||
params.params_A, ptr_A, {params.problem_size.m(), problem_size_k}, thread_idx, tb_offset_A);
|
||||
|
||||
typename Mma::IteratorB iterator_B(
|
||||
params.params_B, ptr_B, {problem_size_k, params.problem_size.n()}, thread_idx, tb_offset_B);
|
||||
|
||||
// Broadcast the warp_id computed by lane 0 to ensure dependent code
|
||||
// is compiled as warp-uniform.
|
||||
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
|
||||
|
||||
int lane_idx = threadIdx.x % 32;
|
||||
|
||||
//
|
||||
// Main loop
|
||||
//
|
||||
|
||||
// Construct thread-scoped matrix multiply
|
||||
Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
|
||||
|
||||
typename Mma::FragmentC accumulators;
|
||||
|
||||
accumulators.clear();
|
||||
|
||||
// Compute threadblock-scoped matrix multiply-add
|
||||
int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK;
|
||||
|
||||
// Compute threadblock-scoped matrix multiply-add
|
||||
mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators);
|
||||
|
||||
//
|
||||
// Masked tile iterators constructed from members
|
||||
//
|
||||
|
||||
threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
|
||||
|
||||
// assume identity swizzle
|
||||
MatrixCoord threadblock_offset(
|
||||
threadblock_tile_offset.m() * Mma::Shape::kM, threadblock_tile_offset.n() * Mma::Shape::kN);
|
||||
|
||||
int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
|
||||
|
||||
//
|
||||
// Construct the epilogue visitor
|
||||
//
|
||||
|
||||
bool with_bias = true;
|
||||
if (params.ptr_C == nullptr) {
|
||||
with_bias = false;
|
||||
}
|
||||
|
||||
EpilogueVisitor epilogue_visitor(
|
||||
params.epilogue_visitor,
|
||||
shared_storage.epilogue.visitor,
|
||||
params.problem_size.mn(),
|
||||
thread_idx,
|
||||
warp_idx,
|
||||
lane_idx,
|
||||
params.params_alpha_col,
|
||||
params.params_C,
|
||||
params.params_D,
|
||||
with_bias,
|
||||
true,
|
||||
true,
|
||||
params.ptr_alpha_row,
|
||||
params.ptr_alpha_col,
|
||||
params.ptr_C,
|
||||
params.ptr_D,
|
||||
threadblock_offset,
|
||||
blockIdx.y * params.problem_size.m());
|
||||
|
||||
if (params.mode == GemmUniversalMode::kGemm) {
|
||||
// Indicate which position in a serial reduction the output operator is currently updating
|
||||
epilogue_visitor.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
|
||||
} else if (params.mode == GemmUniversalMode::kBatched || params.mode == GemmUniversalMode::kArray) {
|
||||
epilogue_visitor.set_batch_index(threadblock_tile_offset.k());
|
||||
}
|
||||
|
||||
// Construct the epilogue
|
||||
Epilogue epilogue(shared_storage.epilogue.epilogue, thread_idx, warp_idx, lane_idx);
|
||||
|
||||
// Execute the epilogue operator to update the destination tensor.
|
||||
epilogue(epilogue_visitor, accumulators);
|
||||
}
|
||||
|
||||
template <typename CompilationArch>
|
||||
CUTLASS_DEVICE void run_kernel(Params const& params, SharedStorage& shared_storage) {
|
||||
if constexpr (platform::is_same<ArchTag, CompilationArch>::value) {
|
||||
run_kernel_(params, shared_storage);
|
||||
} else {
|
||||
CUTLASS_NOT_IMPLEMENTED();
|
||||
}
|
||||
}
|
||||
|
||||
/// Executes one GEMM
|
||||
CUTLASS_DEVICE
|
||||
void operator()(Params const& params, SharedStorage& shared_storage) {
|
||||
run_kernel<ArchTag>(params, shared_storage);
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
Reference in New Issue
Block a user